Differentiating through optimization with the IFT


A startup that recently launched reminded me of an old (embarrassing) report I wrote on differentiating through optimization problems with the implicit function theorem (IFT). If I remember correctly, I was playing with ways to differentiate through massively sparse softmax problems. Unfortunately, I think I gave up after a week. Don’t be me.

Regardless, the IFT is super cool! It lets you differentiate through implicitly defined functions. For example, say you want to optimize the parameters of an optimization problem. The issue is that the solver for this problem uses gradient descent and therefore takes many steps. Reverse-mode autodiff would require storing the entire execution trace to compute gradients, which requires memory linear in the number of iterations. The IFT lets you throw all of that away and compute gradients using only the solution and the optimality conditions, which implicitly define the solution as a function of the parameters.

Solver trajectory vs IFT

A solver wiggles from $x_0$ to the optimum $x^*$ (red), producing a long trace that reverse-mode AD must store. The IFT (blue, dashed) computes $\frac{dx^*}{d\theta}$ directly from the optimality conditions at $x^*$, without looking at the solver’s path.

The IFT shows up in a number of places. OptNet used it to backpropagate through quadratic program solvers, letting neural networks learn constraints (e.g. Sudoku from input-output pairs). Influence functions used it for training data attribution (I guess mainly Anthropic). Meta-learning methods like MAML differentiate through inner-loop optimization, and iMAML used the IFT to do so without unrolling. Hyperparameter optimization and variational inference also fit naturally into this framework. Note: You can also derive autodiff itself from the IFT.

In this post, I’ll walk through the mechanics of applying the IFT to softmax, expressed as an optimization problem. Warning: This post goes a bit deeper into calculations than my previous posts as the IFT is pretty mechanical.

The Implicit Function Theorem

As a starting example, consider the unit circle governed by the relation

$$F(\theta, x) = \theta^2 + x^2 - 1 = 0.$$

Computing the derivative $\frac{dx}{d\theta}$ is not straightforward, as $F$ fails the vertical line test and we cannot write $x$ as a function of $\theta$ globally. However, we can find local parameterizations: on the upper semicircle ($x > 0$), $x = f_1(\theta) = \sqrt{1 - \theta^2}$, and on the lower semicircle ($x < 0$), $x = f_2(\theta) = -\sqrt{1 - \theta^2}$. These let us compute $\frac{dx}{d\theta}$ at particular solution points, though not everywhere (e.g. not at $x = 0$).

Unit circle illustrating the implicit function theorem

The unit circle $F(\theta, x) = \theta^2 + x^2 - 1 = 0$. Near $p_1$, we can locally write $x = \sqrt{1 - \theta^2}$ (green arc), so $\frac{dx}{d\theta}$ exists there. At $p_2 = (1, 0)$, the circle is vertical, so no local parameterization exists and the derivative is undefined.

The IFT generalizes this. Instead of finding a local parameterization, it guarantees one exists and tells you its derivative. The parameterization is left implicit, hence the name.

Formally, given a system of equations $F(\theta, x) = \mathbf{0}_m$ and a solution point $(\theta, x) \in \mathbb{R}^n \times \mathbb{R}^m$, the IFT says there exists a local solution mapping $x^*(\theta)$ if:

  1. We have a solution $F(\theta, x) = 0$.
  2. Continuous first derivatives, $F \in \mathcal{C}^1$.
  3. The Jacobian $\frac{dF(\theta, x)}{dx}$ is nonsingular at the solution point.

When these hold, the derivative of the solution mapping is

$$\frac{dx^*(\theta)}{d\theta} = -\left[\frac{dF(\theta, x)}{dx}\right]^{-1} \frac{dF(\theta, x)}{d\theta} \in \mathbb{R}^{m \times n}.$$

The key point: as long as the conditions hold at a solution point, we can compute this derivative regardless of how we arrived at that solution.
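To make this concrete, here is a small numerical check of the formula on the unit-circle example (the solution point $(\theta, x) = (0.6, 0.8)$ is an arbitrary choice for illustration):

```python
import numpy as np

# Unit-circle relation F(theta, x) = theta^2 + x^2 - 1 = 0.
theta, x = 0.6, 0.8  # a solution point on the upper semicircle
assert abs(theta**2 + x**2 - 1.0) < 1e-12

# Partial derivatives of F at the solution point.
dF_dx = 2.0 * x          # nonsingular since x != 0
dF_dtheta = 2.0 * theta

# IFT: dx*/dtheta = -[dF/dx]^{-1} dF/dtheta = -theta/x.
dx_dtheta_ift = -dF_dtheta / dF_dx

# Compare with the explicit parameterization x = sqrt(1 - theta^2),
# whose derivative is -theta / sqrt(1 - theta^2).
dx_dtheta_explicit = -theta / np.sqrt(1.0 - theta**2)
assert abs(dx_dtheta_ift - dx_dtheta_explicit) < 1e-9
```

Both give $-\theta/x$; note that the IFT version never needed the explicit formula for $x(\theta)$, only the relation $F$ and the solution point.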

The Softmax Optimization Problem

Let’s apply this to something concrete: softmax. Softmax has a known Jacobian, so we can verify the IFT gives the right answer. The exercise is to express softmax as an optimization problem, then derive its Jacobian using the IFT.

Softmax is defined as follows. Given $n$ items with independent utilities $\theta \in \mathbb{R}^n$, softmax yields a distribution over items:

$$z_i = \mathrm{softmax}(\theta)_i = \frac{\exp(\theta_i)}{\sum_j \exp(\theta_j)},$$

with $z \in [0,1]^n$.
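In code, softmax is usually implemented with the standard max-subtraction trick for numerical stability (a generic sketch, not tied to any particular library):

```python
import numpy as np

def softmax(theta):
    # Subtracting max(theta) leaves the output unchanged
    # (the shift cancels in the ratio) but prevents overflow in exp.
    e = np.exp(theta - np.max(theta))
    return e / e.sum()

z = softmax(np.array([1.0, 2.0, 3.0]))
assert abs(z.sum() - 1.0) < 1e-12 and np.all(z > 0)
```

The trick is valid because softmax is shift invariant: $\mathrm{softmax}(\theta + c\mathbf{1}_n) = \mathrm{softmax}(\theta)$ for any scalar $c$.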

Softmax is also the solution of the following constrained optimization problem:

$$\begin{aligned} \textrm{maximize } \quad & z^\top\theta + H(z)\\ \textrm{subject to } \quad & z^\top \mathbf{1}_n = 1\\ & z_i \geq 0, \ \forall i, \end{aligned}$$

where $H(z) = -\sum_i z_i \log z_i$ is the entropy.
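As a sanity check that this optimization problem really recovers softmax, we can solve it with mirror descent (exponentiated gradient), which keeps the iterates on the simplex; the step size and iteration count below are arbitrary choices for illustration:

```python
import numpy as np

theta = np.array([0.5, -1.0, 2.0])

# Maximize z^T theta + H(z) over the simplex via exponentiated gradient:
# multiplicative update in the gradient direction, then renormalize.
z = np.full(3, 1.0 / 3.0)  # start at the uniform distribution
for _ in range(200):
    grad = theta - np.log(z) - 1.0  # gradient of z^T theta + H(z)
    z = z * np.exp(0.5 * grad)      # step size 0.5, chosen arbitrarily
    z /= z.sum()                    # stay on the simplex

# The iterate converges to softmax(theta).
z_softmax = np.exp(theta) / np.exp(theta).sum()
assert np.max(np.abs(z - z_softmax)) < 1e-8
```

Multiplicative updates keep every $z_i > 0$, so the $\log z$ in the gradient is always well defined.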

Our goal is to compute the Jacobian of softmax

$$\frac{dz}{d\theta} = \frac{d\,\mathrm{softmax}(\theta)}{d\theta}$$

using the IFT and the optimization problem above.

Applying the IFT

Applying the IFT consists of four steps:

  1. Find a solution to the optimization problem.
  2. Write down a system of equations derived from the optimality conditions.
  3. Check that the conditions of the IFT hold.
  4. Compute the derivative of the implicit solution mapping with respect to the parameters.

We assume the first step has been done for us, and we have a solution $z$ to the softmax problem. Most of these manipulations are math-homework flavored.

Step 2: The KKT conditions determine the system of equations

In order to apply the IFT, we need a system of equations for which our outputs of interest are solution points. For solutions of optimization problems, the Karush-Kuhn-Tucker (KKT) conditions are a natural choice for defining such a system. Given an optimization problem, the KKT conditions determine a system of equations that the solution must satisfy based on the optimality criteria. The two we need here are stationarity (the gradient of the Lagrangian should be $0$ at a local optimum) and feasibility (the constraints of the problem should not be violated).

We will use the KKT conditions of the softmax problem to determine the vector-valued function $F : \mathbb{R}^n \times \mathbb{R}^{n+1} \to \mathbb{R}^{n+1}$ in the IFT. The optimization problem has both equality and inequality constraints, but for finite $\theta$ the softmax solution is always strictly positive:

$$z_i = \frac{\exp(\theta_i)}{\sum_j \exp(\theta_j)} > 0.$$

That means the inequality constraints are inactive at the solution points we care about, so it is cleaner to work with the reduced KKT system containing only the equality constraint. We therefore introduce only the equality multiplier $u \in \mathbb{R}$ and write out the Lagrangian:

$$\mathcal{L}(\theta, z, u) = z^\top\theta + H(z) + u(z^\top \mathbf{1}_n - 1).$$

We therefore have the solution point $(\theta, z, u)$, with parameters $\theta$ and solution $x = (z, u)$. We then have the following necessary conditions for a solution $(z, u)$, i.e. the KKT conditions:

$$\begin{aligned} \frac{\partial}{\partial z} \mathcal{L}(\theta, z, u) &= \mathbf{0}_n && \textrm{(stationarity)}\\ z^\top \mathbf{1}_n - 1 &= 0 && \textrm{(primal feasibility, equality)}\\ z &\succeq \mathbf{0}_n && \textrm{(primal feasibility, inequality)}. \end{aligned}$$

As we only need a system of $n+1$ equations to determine the $n+1$ solution variables $x = (z, u)$, we use the first two conditions: stationarity and primal feasibility (equality).

In full, the system of equations $F(\theta, z, u) = 0$ we choose for the softmax problem is

$$\begin{aligned} \theta - \log(z) - \mathbf{1}_n + u\mathbf{1}_n &= \mathbf{0}_n\\ z^\top \mathbf{1}_n - 1 &= 0. \end{aligned}$$
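We can verify numerically that the softmax solution satisfies this system. Plugging $\log z_i = \theta_i - \mathrm{logsumexp}(\theta)$ into stationarity gives the multiplier $u = 1 - \mathrm{logsumexp}(\theta)$:

```python
import numpy as np

theta = np.array([0.3, -1.2, 0.9, 2.0])
lse = np.log(np.sum(np.exp(theta)))  # logsumexp(theta)
z = np.exp(theta - lse)              # softmax(theta)
u = 1.0 - lse                        # multiplier implied by stationarity

# Stationarity: theta - log(z) - 1 + u * 1 = 0.
stationarity = theta - np.log(z) - 1.0 + u
# Primal feasibility (equality): z^T 1 - 1 = 0.
feasibility = z.sum() - 1.0

assert np.max(np.abs(stationarity)) < 1e-12
assert abs(feasibility) < 1e-12
```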

Step 3: Check that the IFT conditions hold at the solution point

The IFT only applies if the following three conditions hold; they must be checked on a case-by-case basis for particular solution points. Note that even when the IFT does not apply, the derivative $\frac{dx}{d\theta}$ may still be computed via other means if it exists.

  • $F(\theta, z, u) = 0$,
  • $F$ has at least continuous first derivatives,
  • $\det \frac{dF(\theta, z, u)}{d(z, u)} \ne 0$, or equivalently $\frac{dF(\theta, z, u)}{d(z, u)}$ is full rank.

In the softmax problem, the first condition holds as we have a solution to the optimization problem and $F$ was chosen using the KKT conditions. The second condition also holds, as $F$ has continuous first derivatives. All that remains is to check the third condition, that the Jacobian matrix

$$\frac{dF(\theta, z, u)}{d(z, u)}$$

(evaluated at the solution point) is non-singular.

The Jacobian matrix is given by

$$\frac{dF(\theta, z, u)}{d(z, u)} = \begin{bmatrix} -\mathrm{diag}(z)^{-1} & \mathbf{1}_n \\ \mathbf{1}_n^\top & 0 \end{bmatrix}.$$

The upper-left block is $-\mathrm{diag}(z)^{-1}$ (from the Hessian of $H$); its entries blow up if any component $z_i = 0$. We saw a similar issue in the unit circle example, where the derivative $\frac{dx}{d\theta}$ was undefined when $x = 0$. Luckily, softmax already satisfies $z_i > 0$ for finite $\theta$.

Once $z \succ \mathbf{0}_n$, the block $-\mathrm{diag}(z)^{-1}$ is invertible. The Schur complement of the upper-left block is

$$0 - \mathbf{1}_n^\top\left(-\mathrm{diag}(z)\right)\mathbf{1}_n = z^\top \mathbf{1}_n = 1,$$

where we used primal feasibility. Therefore the Jacobian of $F$ is nonsingular. This shows that the conditions of the IFT hold for the solution points that are feasible, optimal, and have strictly positive $z$. Recall that the IFT applies at particular solution points, meaning we can pick and choose which points to analyze.
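This argument can be confirmed numerically by building the KKT Jacobian at a softmax solution (the particular $\theta$ is an arbitrary choice):

```python
import numpy as np

theta = np.array([0.1, 1.5, -0.7])
z = np.exp(theta - np.max(theta))
z /= z.sum()  # softmax(theta), strictly positive
n = len(z)

# KKT Jacobian dF/d(z, u): [[-diag(z)^{-1}, 1], [1^T, 0]].
J = np.zeros((n + 1, n + 1))
J[:n, :n] = -np.diag(1.0 / z)
J[:n, n] = 1.0
J[n, :n] = 1.0

# Schur complement of the upper-left block: 0 - 1^T (-diag(z)) 1 = sum(z).
schur = 0.0 - np.ones(n) @ (-np.diag(z)) @ np.ones(n)
assert abs(schur - 1.0) < 1e-12
assert np.linalg.matrix_rank(J) == n + 1  # nonsingular
```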

Step 4: Compute $\frac{dx}{d\theta}$

Now that we have a set of solution points where the IFT holds, we can use the IFT to compute $\frac{dz}{d\theta}$. Recall that we have the solution $x = (z, u)$. The second part of the IFT tells us that we can compute the Jacobian of the solution mapping

$$\frac{dx}{d\theta} = \frac{dx^*(\theta)}{d\theta} = -\left[\frac{dF(\theta, x)}{dx}\right]^{-1}\frac{dF(\theta, x)}{d\theta},$$

then pick out the relevant components.

The second term $\frac{dF(\theta, x)}{d\theta}$ is simple. Since $\theta$ only appears in the first block of equations of $F$, we have

$$\frac{dF(\theta, x)}{d\theta} = \begin{bmatrix} I_{n \times n}\\ \mathbf{0}_{1 \times n} \end{bmatrix}.$$

Rather than write down the full inverse explicitly, it is simpler to solve the linear system induced by the IFT identity. Differentiating the equations in $F(\theta, z, u) = 0$ gives

$$\begin{aligned} d\theta - \mathrm{diag}(z)^{-1}dz + \mathbf{1}_n\,du &= \mathbf{0}_n\\ \mathbf{1}_n^\top dz &= 0. \end{aligned}$$

Solving the first equation for $dz$ yields

$$dz = \mathrm{diag}(z)\left(d\theta + \mathbf{1}_n\,du\right).$$

Substituting this into the equality constraint gives

$$\begin{aligned} 0 &= \mathbf{1}_n^\top dz\\ &= \mathbf{1}_n^\top \mathrm{diag}(z)\left(d\theta + \mathbf{1}_n\,du\right)\\ &= z^\top d\theta + \left(z^\top \mathbf{1}_n\right)du\\ &= z^\top d\theta + du, \end{aligned}$$

where we again used feasibility: $z^\top \mathbf{1}_n = 1$. Therefore

$$du = -z^\top d\theta.$$

Plugging this back in,

$$\begin{aligned} dz &= \mathrm{diag}(z)\left(d\theta - \mathbf{1}_n z^\top d\theta\right)\\ &= \left(\mathrm{diag}(z) - zz^\top\right)d\theta. \end{aligned}$$

Therefore the Jacobian of the solution mapping is

$$\frac{dz}{d\theta} = \mathrm{diag}(z) - zz^\top,$$

which agrees with the analytic Jacobian.
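The whole derivation can be cross-checked numerically: assemble the KKT Jacobian, solve the IFT linear system, and compare with the closed form (a sketch, with an arbitrary test point $\theta$):

```python
import numpy as np

theta = np.array([0.2, -0.5, 1.3])
z = np.exp(theta - np.max(theta))
z /= z.sum()  # softmax(theta)
n = len(z)

# dF/d(z, u) and dF/dtheta from the softmax KKT system.
J = np.zeros((n + 1, n + 1))
J[:n, :n] = -np.diag(1.0 / z)
J[:n, n] = 1.0
J[n, :n] = 1.0
dF_dtheta = np.vstack([np.eye(n), np.zeros((1, n))])

# IFT: d(z, u)/dtheta = -J^{-1} dF/dtheta, solved rather than inverted.
dx_dtheta = -np.linalg.solve(J, dF_dtheta)
dz_dtheta = dx_dtheta[:n, :]  # drop the multiplier's row

analytic = np.diag(z) - np.outer(z, z)
assert np.max(np.abs(dz_dtheta - analytic)) < 1e-10
```

The last row of `dx_dtheta` recovers $\frac{du}{d\theta} = -z^\top$, matching the intermediate step of the derivation.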

For this toy example, the special structure of the KKT system lets us solve the linear system in closed form. However, solving the IFT system in general requires computing the inverse Hessian of the Lagrangian, which takes $O(n^3)$ time. In practice, we typically avoid materializing the full Jacobian and instead compute JVPs or VJPs through the implicit function, where the cost can be further alleviated with approximate inverse-Hessian-vector-product methods.