Differentiating through an associative parallel scan


State-space models and linear (matrix-valued) RNNs have recently grown greatly in popularity, thanks to efficient hardware implementations. In particular, parallel associative scans compute all prefix reductions of a length-$T$ sequence in $O(\log T)$ time on parallel hardware (given enough processors).

Implementing these parallel scans requires tremendous care: naively, they require much more memory than a sequential $O(T)$ scan. In practice, people get around this with checkpointing and bespoke implementations of the forward and backward passes. Writing backward passes by hand is painful; this is the whole problem automatic differentiation is supposed to solve!

In this blog post, we’re going to work through the backward pass of an associative scan. This post is inspired by a problem my advisor, Sasha, ran into when implementing state-space models.

Parallel scans

Say we have a sequence of inputs $x_1, x_2, \ldots, x_T$, which can be vectors, flattened matrices, etc. Given the sequence of partial products (i.e., the cumulative product) under an associative operator $\oplus$,

\begin{equation}
h_t = x_t \oplus h_{t-1} = x_t \oplus x_{t-1} \oplus \cdots \oplus x_1,
\end{equation}

we want to optimize a loss function $L(h_1, \ldots, h_T)$ via gradient descent.
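To make the $O(\log T)$ claim concrete, here is a minimal sketch of a Hillis-Steele-style inclusive scan for equation (1) in plain Python, with any associative `op(newer, older)` standing in for $\oplus$. The function names are hypothetical, and each list comprehension simulates one step that parallel hardware would do at once; this is an illustration, not how production SSM kernels are written.

```python
from typing import Callable, List, TypeVar

E = TypeVar("E")


def sequential_scan(xs: List[E], op: Callable[[E, E], E]) -> List[E]:
    """Reference: h_1 = x_1, h_t = op(x_t, h_{t-1}), one step at a time."""
    hs = [xs[0]]
    for x in xs[1:]:
        hs.append(op(x, hs[-1]))
    return hs


def parallel_scan(xs: List[E], op: Callable[[E, E], E]) -> List[E]:
    """Hillis-Steele inclusive scan in ceil(log2(T)) rounds.

    Invariant: after the round with offset d, ys[t] combines the inputs
    x_t, ..., x_{t-2d+1} (0-based, clipped at the first input).
    All updates inside one round are independent, so on parallel
    hardware each round is a single step.
    """
    ys = list(xs)
    d = 1
    while d < len(ys):
        # Newer block on the left, matching h_t = x_t (+) h_{t-1}.
        ys = [ys[t] if t < d else op(ys[t], ys[t - d]) for t in range(len(ys))]
        d *= 2
    return ys


# String concatenation is associative but not commutative, so it checks ordering.
print(parallel_scan(list("abcde"), lambda a, b: a + b))
# ['a', 'ba', 'cba', 'dcba', 'edcba']
```

This variant trades extra total work ($O(T \log T)$ applications of `op`) for its short depth; work-efficient scans exist but are longer to write down.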

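Therefore, our goal is to compute the following gradient efficiently, where $\frac{\partial L}{\partial h_t}$ denotes the derivative of $L$ with respect to its $t$-th argument alone: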

\begin{align}
\frac{\partial L}{\partial x_i} &= \sum_{t=i}^T \frac{\partial L}{\partial h_t}\frac{\partial h_t}{\partial x_i}\\
&= \sum_{t=i}^T \frac{\partial L}{\partial h_t}\frac{\partial h_t}{\partial h_{t-1}} \cdots\frac{\partial h_i}{\partial x_i}.
\end{align}
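As a concrete (hypothetical) instance, take scalar inputs with $\oplus$ being multiplication, so $\frac{\partial h_t}{\partial h_{t-1}} = x_t$ and $\frac{\partial h_i}{\partial x_i} = h_{i-1}$ (with $h_0 = 1$), and the loss $L = \frac{1}{2}\sum_t h_t^2$, so $\frac{\partial L}{\partial h_t} = h_t$. The sketch below evaluates equation (3) term by term with 0-based lists. It costs $O(T^2)$ work but serves as a reference that can be checked against finite differences.

```python
def forward(xs):
    """h_t = x_t * h_{t-1} with h_0 = 1 (0-based lists: hs[t] is h_{t+1})."""
    hs, h = [], 1.0
    for x in xs:
        h = x * h
        hs.append(h)
    return hs


def loss(xs):
    """Hypothetical loss L = 0.5 * sum_t h_t^2, so dL/dh_t = h_t."""
    return 0.5 * sum(h * h for h in forward(xs))


def grad_direct(xs):
    """Evaluate equation (3) term by term: O(T^2) work, used as a reference."""
    hs = forward(xs)
    n = len(xs)
    grads = []
    for i in range(n):
        dh_dx = hs[i - 1] if i > 0 else 1.0  # dh_i/dx_i = h_{i-1}
        chain = 1.0                          # running product dh_t/dh_i
        total = 0.0
        for t in range(i, n):
            if t > i:
                chain *= xs[t]               # dh_t/dh_{t-1} = x_t
            total += hs[t] * chain * dh_dx   # dL/dh_t = h_t
        grads.append(total)
    return grads


print(grad_direct([0.9, 1.1, -0.7, 1.3, 0.5]))
```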

Recurrent form

To write this in a recurrent form like equation (1), it is convenient to work with the transpose of the gradient:

\begin{align}
\frac{\partial L}{\partial x_i}^\top &= \sum_{t=i}^T \frac{\partial h_i}{\partial x_i}^\top \cdots\frac{\partial h_t}{\partial h_{t-1}}^\top \frac{\partial L}{\partial h_t}^\top\\
&= \frac{\partial h_i}{\partial x_i}^\top \frac{\partial L}{\partial h_i}^\top + \frac{\partial h_i}{\partial x_i}^\top \frac{\partial h_{i+1}}{\partial h_i}^\top \frac{\partial L}{\partial h_{i+1}}^\top + \cdots\\
&= \frac{\partial h_i}{\partial x_i}^\top \left(\frac{\partial L}{\partial h_i}^\top + \frac{\partial h_{i+1}}{\partial h_i}^\top \frac{\partial L}{\partial h_{i+1}}^\top + \cdots\right)\\
&= \frac{\partial h_i}{\partial x_i}^\top \left(\frac{\partial L}{\partial h_i}^\top + \frac{\partial h_{i+1}}{\partial h_i}^\top \left(\frac{\partial L}{\partial h_{i+1}}^\top + \cdots\right)\right).
\end{align}

This nesting can be written as a recurrence that runs backward from $i = T$:

\begin{equation}
g_i = \frac{\partial L}{\partial h_i}^\top+\frac{\partial h_{i+1}}{\partial h_{i}}^\top g_{i+1},
\end{equation}

where the second term is dropped at $i = T$ (so $g_T = \frac{\partial L}{\partial h_T}^\top$), and

\begin{equation}
\frac{\partial L}{\partial x_i}^\top = \frac{\partial h_i}{\partial x_i}^\top g_i.
\end{equation}

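Equation (8) is linear in $g$: each step applies the affine map $g \mapsto \frac{\partial L}{\partial h_i}^\top + \frac{\partial h_{i+1}}{\partial h_i}^\top g$, and composition of affine maps is associative. The recurrence can therefore be computed with a parallel scan run from $i = T$ down to $1$, after which each $g_i$ is multiplied by its Jacobian in equation (9), independently for every $i$.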
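Continuing the hypothetical scalar example (multiplication as $\oplus$, $L = \frac{1}{2}\sum_t h_t^2$), here is a sketch of equations (8) and (9). Each step of equation (8) is stored as a pair $(a_i, b_i)$ meaning $g \mapsto a_i g + b_i$; the pairs are composed with a Hillis-Steele scan over the reversed sequence, and the result is scaled by $\frac{\partial h_i}{\partial x_i}$. Scalars keep the sketch short; with matrices, $a_i$ would be the transposed Jacobian and the products would be matrix products.

```python
def forward(xs):
    """Hypothetical example: h_t = x_t * h_{t-1}, h_0 = 1 (hs[t] is h_{t+1})."""
    hs, h = [], 1.0
    for x in xs:
        h = x * h
        hs.append(h)
    return hs


def compose(f, g):
    """Affine maps stored as (a, b) meaning v -> a * v + b; returns f o g."""
    a_f, b_f = f
    a_g, b_g = g
    return (a_f * a_g, a_f * b_g + b_f)


def scan(xs, op):
    """Hillis-Steele inclusive scan: ys[t] = xs[t] op xs[t-1] op ... op xs[0]."""
    ys, d = list(xs), 1
    while d < len(ys):
        ys = [ys[t] if t < d else op(ys[t], ys[t - d]) for t in range(len(ys))]
        d *= 2
    return ys


def grad_scan(xs):
    """Gradient of L = 0.5 * sum_t h_t^2 via equations (8) and (9)."""
    hs = forward(xs)
    n = len(xs)
    # Step i of equation (8): g_i = dL/dh_i + (dh_{i+1}/dh_i) * g_{i+1},
    # with dL/dh_i = h_i, dh_{i+1}/dh_i = x_{i+1}, and no second term at i = T.
    maps = [(xs[i + 1] if i + 1 < n else 0.0, hs[i]) for i in range(n)]
    # Scan from the last step backward: entry i becomes f_i o f_{i+1} o ... o f_T.
    composed = scan(maps[::-1], compose)[::-1]
    gs = [b for _, b in composed]  # each composed map evaluated at 0
    # Equation (9): scale by dh_i/dx_i = h_{i-1} (h_0 = 1).
    return [(hs[i - 1] if i > 0 else 1.0) * gs[i] for i in range(n)]


print(grad_scan([0.9, 1.1, -0.7, 1.3, 0.5]))
```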

Grid form

We present another view. Let’s trek all the way back to equation (3),

\[
\frac{\partial L}{\partial x_i} = \sum_{t=i}^T \frac{\partial L}{\partial h_t}\frac{\partial h_t}{\partial h_{t-1}} \cdots\frac{\partial h_i}{\partial x_i},
\]

and try writing it out for a couple of steps:

\[
\begin{array}{ccccc}
 & \frac{\partial L}{\partial h_i} & & & \frac{\partial h_i}{\partial x_i}\\
+ & \frac{\partial L}{\partial h_{i+1}} & & \frac{\partial h_{i+1}}{\partial h_i} & \frac{\partial h_i}{\partial x_i}\\
+ & \frac{\partial L}{\partial h_{i+2}} & \frac{\partial h_{i+2}}{\partial h_{i+1}} & \frac{\partial h_{i+1}}{\partial h_i} & \frac{\partial h_i}{\partial x_i}
\end{array}
\]

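Each row is a grad output $\frac{\partial L}{\partial h_t}$, times a cumulative product of Jacobians $P_t = \frac{\partial h_t}{\partial h_{t-1}} \cdots \frac{\partial h_{i+1}}{\partial h_i}$ (with $P_i = I$), times $\frac{\partial h_i}{\partial x_i}$. So $\frac{\partial L}{\partial x_i}$ is a simple combination of the cumulative product $P_i, \ldots, P_T$ and a sum of terms derived from it.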

On parallel hardware, the cumulative product can be computed in $O(\log T)$ time, the multiplication of each term of that product with its grad output in $O(1)$ time, and the sum reduction in $O(\log T)$ time.

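Of course, this is just for one term $\frac{\partial L}{\partial x_i}$. To handle every $i$ at once, start the cumulative product at $t = 1$ instead, $C_t = \frac{\partial h_t}{\partial h_{t-1}} \cdots \frac{\partial h_2}{\partial h_1}$ (with $C_1 = I$), so that $P_t = C_t C_i^{-1}$ and
\[
\frac{\partial L}{\partial x_i} = \left(\sum_{t=i}^T \frac{\partial L}{\partial h_t} C_t\right) C_i^{-1} \frac{\partial h_i}{\partial x_i}.
\]
The sum reduction can then be swapped for a single reverse cumulative sum shared by all $i$, with each entry scaled by $C_i^{-1}$ and the appropriate $\frac{\partial h_i}{\partial x_i}$. This requires the Jacobians to be invertible (for example, a diagonal recurrence with nonzero entries); the recurrence in equation (8) needs no such assumption.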
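Here is a sketch of this double scan for the same hypothetical scalar example, where each Jacobian $x_t$ is invertible as long as it is nonzero. `itertools.accumulate` stands in for the two parallel scans (a cumulative product and a reverse cumulative sum). Dividing by $C_i$ can be numerically fragile when the products grow or shrink quickly, so treat this as an illustration of the algebra rather than a recommended kernel.

```python
import operator
from itertools import accumulate


def forward(xs):
    """Hypothetical example: h_t = x_t * h_{t-1}, h_0 = 1 (hs[t] is h_{t+1})."""
    hs, h = [], 1.0
    for x in xs:
        h = x * h
        hs.append(h)
    return hs


def grad_double_scan(xs):
    """Gradient of L = 0.5 * sum_t h_t^2 via a cumulative product and a
    reverse cumulative sum. Assumes every x_t (the Jacobian) is nonzero."""
    hs = forward(xs)
    n = len(xs)
    # Scan 1 (cumulative product): C_t = dh_t/dh_1 = x_t * ... * x_2, C_1 = 1.
    jacobians = [1.0] + list(xs[1:])
    C = list(accumulate(jacobians, operator.mul))
    # Scan 2 (reverse cumulative sum): S_i = sum_{t >= i} dL/dh_t * C_t.
    terms = [hs[t] * C[t] for t in range(n)]
    S = list(accumulate(reversed(terms)))[::-1]
    # Undo the shared prefix (C_i^{-1}), then scale by dh_i/dx_i = h_{i-1}.
    return [S[i] / C[i] * (hs[i - 1] if i > 0 else 1.0) for i in range(n)]


print(grad_double_scan([0.9, 1.1, -0.7, 1.3, 0.5]))
```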

Generalization

It turns out that neither the recurrence nor the double scan relied on the associativity of the forward operator. Intuition: the derivative is the best linear approximation at a particular point, so even if the forward step is not associative, the backward step is linear, and composing linear maps is associative. (The Jacobians still depend on the forward values $h_t$, which must be computed first.) It also turns out that this has been taken advantage of in past work (Wang, Bai, Pekhimenko, 2020). Curious whether we can apply this to transformers in some way too! Maybe in the next blog post.
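To illustrate, here is a sketch with a hypothetical forward step that is a generic nonlinear recurrence rather than an associative $\oplus$: a scalar tanh RNN $h_t = \tanh(w h_{t-1} + x_t)$ with $h_0 = 0$, an assumed weight $w = 0.8$, and the same loss $L = \frac{1}{2}\sum_t h_t^2$. The forward pass runs sequentially, but the backward pass only needs $\frac{\partial h_t}{\partial h_{t-1}} = w(1 - h_t^2)$ and $\frac{\partial h_t}{\partial x_t} = 1 - h_t^2$, so equation (8) is still an affine recurrence that the same scan handles.

```python
import math

W = 0.8  # hypothetical recurrent weight


def forward(xs, w=W):
    """Nonlinear step h_t = tanh(w * h_{t-1} + x_t), h_0 = 0, run sequentially."""
    hs, h = [], 0.0
    for x in xs:
        h = math.tanh(w * h + x)
        hs.append(h)
    return hs


def compose(f, g):
    """Affine maps stored as (a, b) meaning v -> a * v + b; returns f o g."""
    a_f, b_f = f
    a_g, b_g = g
    return (a_f * a_g, a_f * b_g + b_f)


def scan(xs, op):
    """Hillis-Steele inclusive scan: ys[t] = xs[t] op xs[t-1] op ... op xs[0]."""
    ys, d = list(xs), 1
    while d < len(ys):
        ys = [ys[t] if t < d else op(ys[t], ys[t - d]) for t in range(len(ys))]
        d *= 2
    return ys


def grad_scan(xs, w=W):
    """Gradient of L = 0.5 * sum_t h_t^2: the backward pass is still a scan."""
    hs = forward(xs, w)
    n = len(xs)
    local = [1.0 - h * h for h in hs]  # tanh derivative at step t, i.e. dh_t/dx_t
    # Equation (8) with dh_{i+1}/dh_i = w * (1 - h_{i+1}^2) and dL/dh_i = h_i.
    maps = [(w * local[i + 1] if i + 1 < n else 0.0, hs[i]) for i in range(n)]
    gs = [b for _, b in scan(maps[::-1], compose)[::-1]]
    # Equation (9): scale by dh_i/dx_i.
    return [local[i] * gs[i] for i in range(n)]


print(grad_scan([0.5, -1.2, 0.3, 0.9, -0.4]))
```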

Acknowledgements

Thanks to Songlin Yang and Sasha Rush for catching bugs!