Differentiating through an associative parallel scan
State-space models and linear (matrix-valued) RNNs
have recently surged in popularity,
thanks to efficient hardware implementations.
In particular, parallel associative scans allow one to compute
reductions of sequences of length $T$ in $O(\log T)$ time
on parallel hardware.
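To make this concrete, here is a minimal JAX sketch of computing all cumulative matrix products of a sequence with `jax.lax.associative_scan`, which evaluates the prefix reduction in $O(\log T)$ depth on parallel hardware. The $2 \times 2$ matrices, sizes, and names are my own toy choices, not anything from a particular implementation:

```python
import jax
import jax.numpy as jnp

def compose(a, b):
    # Binary associative operator: batched matrix multiplication,
    # arranged so the later element multiplies on the left.
    return jnp.matmul(b, a)

T = 8
xs = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (T, 2, 2))

# hs[t] = xs[t] @ xs[t-1] @ ... @ xs[0], computed as a parallel prefix scan.
hs = jax.lax.associative_scan(compose, xs)

# Sequential O(T) reference for comparison.
h = xs[0]
for t in range(1, T):
    h = xs[t] @ h
assert jnp.allclose(hs[-1], h, rtol=1e-4, atol=1e-5)
```

The only requirement on the operator is associativity, which is what lets the scan regroup the $T-1$ combines into a balanced tree of depth $O(\log T)$.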
Implementing these parallel reductions requires tremendous care:
naively, they would require much more memory than a sequential $O(T)$ scan.
In practice, people get around this via checkpointing and bespoke
implementations of the forward and backward passes.
Manual implementation of backward passes is painful — this is the whole problem
automatic differentiation is supposed to solve!
In this blog post, we’re going to work through
the backward pass of an associative scan.
This post is inspired by a problem my advisor, Sasha,
ran into when implementing state space models.
Parallel scans
Say we have a sequence of inputs $x_1, x_2, \ldots, x_T$ (say $T = 5$), which can be vectors, flattened matrices, etc.
Given a sequence of partial products, i.e. the cumulative product,
$$h_t = x_t \oplus h_{t-1} = x_t \oplus x_{t-1} \oplus \cdots \oplus x_1,$$
we want to optimize a loss function $L(h_1, \ldots, h_T)$ via gradient descent.
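To ground this, here is a minimal JAX version of the setup that I'll reuse below. The choice of elementwise multiplication as $\oplus$, the toy loss, the positive inputs, and all the names are my own, rather than anything from the original implementation:

```python
import jax
import jax.numpy as jnp

T, D = 6, 3
x = jnp.exp(0.1 * jax.random.normal(jax.random.PRNGKey(0), (T, D)))  # positive inputs

def forward(x):
    # h_t = x_t * x_{t-1} * ... * x_1 (elementwise), as a parallel prefix scan.
    return jax.lax.associative_scan(jnp.multiply, x)

def loss(x):
    h = forward(x)
    return jnp.sum(jnp.sin(h))  # arbitrary scalar loss L(h_1, ..., h_T)

# What we want: dL/dx. JAX's autodiff will happily give it to us...
grads_reference = jax.grad(loss)(x)
# ...but the question is what that backward pass looks like written out,
# and whether it can itself be expressed as parallel scans.
```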
Therefore, our goal is to compute the following gradient efficiently:

$$\frac{\partial L}{\partial x_i} = \sum_{t=i}^{T} \frac{\partial L}{\partial h_t} \frac{\partial h_t}{\partial h_i} \frac{\partial h_i}{\partial x_i},
\qquad
\frac{\partial h_t}{\partial h_i} = \prod_{s=i}^{t-1} \frac{\partial h_{s+1}}{\partial h_s}.$$
This is a simple combination of a cumulative product,
$\prod_{t=i}^{T-1} \frac{\partial h_{t+1}}{\partial h_t}$,
and a sum of terms derived from that product.
On parallel hardware, the cumulative product can be computed in $O(\log T)$ time,
the multiplication of each term of that cumulative product with the grad output can be computed in $O(1)$,
and the sum reduction in $O(\log T)$ time.
Of course, this is just for one term $\frac{\partial L}{\partial x_i}$.
To handle every $i$ at once, the sum reduction can be swapped for a (reverse) cumulative sum,
with each result then scaled by the appropriate $\frac{\partial h_j}{\partial x_j}$.
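Putting the pieces together, here is a sketch of this double-scan backward for the elementwise setup above, where every Jacobian $\frac{\partial h_{t+1}}{\partial h_t}$ is just $\mathrm{diag}(x_{t+1})$, checked against JAX's own autodiff. The names are mine, and the division is a shortcut that works because elementwise Jacobians commute and the inputs are positive; it is not what a production implementation would do:

```python
import jax
import jax.numpy as jnp

T, D = 6, 3
x = jnp.exp(0.1 * jax.random.normal(jax.random.PRNGKey(0), (T, D)))  # positive inputs

def forward(x):
    # h_t = x_t * x_{t-1} * ... * x_1 (elementwise), as a parallel prefix scan.
    return jax.lax.associative_scan(jnp.multiply, x)

def loss(x):
    return jnp.sum(jnp.sin(forward(x)))  # arbitrary scalar loss L(h_1, ..., h_T)

h = forward(x)
g = jax.grad(lambda h: jnp.sum(jnp.sin(h)))(h)  # grad outputs dL/dh_t

# Scan 1: cumulative product of the Jacobians diag(x_s). Elementwise, the
# prefix products are just h itself, and dh_t/dh_i = h_t / h_i.
# O(1) step: weight each grad output by its full prefix product.
weighted = g * h
# Scan 2: a reverse cumulative sum collects sum_{t >= i} g_t * h_t in one pass;
# dividing by h_i turns each term into g_t * dh_t/dh_i, giving dL/dh_i.
grads_h = jax.lax.associative_scan(jnp.add, weighted, reverse=True) / h
# Finally scale by dh_i/dx_i = h_{i-1} (with h_0 = 1).
h_prev = jnp.concatenate([jnp.ones((1, D)), h[:-1]], axis=0)
grads_x = grads_h * h_prev

assert jnp.allclose(grads_x, jax.grad(loss)(x), rtol=1e-4, atol=1e-5)
```

The same recipe carries over to matrix-valued Jacobians, except the partial products no longer commute, so you would keep the reversed (suffix) products rather than dividing by a prefix.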
Generalization
It turns out that the computation in both the recurrence and the double scan
does not actually rely on the associativity of the forward operator.
Intuition: the derivative is the best linear approximation at a particular point,
so even if the forward operator is not associative, the backward is built from composing linear maps, and composing linear maps is associative.
It also turns out that this has been taken advantage of in past work (Wang, Bai, Pekhimenko, 2020).
Curious if we can apply this to transformers in some way too!
Maybe in the next blog post.