HiPPOs 🦛, Mambas 🐍, and other creatures

I go through the research journey that led to Mamba. I first review SSMs, then explore the models: HiPPO, S4, DSS, and finally Mamba.
Beyond the transformer
Author

Oleguer Canal

Published

May 18, 2024

Have you ever wondered: Wait, what if someone explained Mamba worse? Well, buckle up 💺, because today we’re doing that! First, we are going to review SSMs1. Then, we’ll check out recent deep learning contributions to them2. And with this, we’ll finally be ready to understand Mamba: the strongest contender to Transformer dominance 😱 Mamba matches the performance of transformers twice its size while keeping its memory requirements sub-quadratic in sequence length! 3

1 Short for State Space Models

2 I will mainly focus on the line of research of Albert Gu, Tri Dao et al.

3 This post assumes you are familiar with how basic RNN, CNN, and Transformer models work.

My thoughts when I first learned about SSMs.

State Space Models (SSMs)

SSMs are traditionally used in control theory to model a dynamical system4 via state variables5. Given:

4 Fancy way of saying that functions are dependent on time (or that they are sequential)

5 SSMs’ core ideas are so omnipresent that they have been rediscovered multiple times, resulting in a terminology mess… Control nerds call “state variables” what deep learning geeks call “hidden states”, which are the same as “latent variables” in the probabilistic world. Potatos, potatoes, potatoisses 🥔

  • An N-dim input time series6: \(\space x(t) \in \mathbb{R}^N\)

  • A P-dim output time series7: \(\space y(t) \in \mathbb{R}^P\)

6 Think of a trajectory (aka line) in a \(N\)-dimensional space.

7 Think of another trajectory but now in a \(P\)-dimensional space 🤪

SSMs model the mapping between both time-series with this linear differential equation:

\[ \begin{split} & h^\prime (t) &= A h(t) + B x(t) \\ & y (t) &= C h(t) + D x(t) \end{split} \tag{1}\]

Where:

  • \(h(t) \in \mathbb{R}^H\)

  • \(A \in \mathbb{R}^{H \times H}\)8, \(B \in \mathbb{R}^{H \times N}\), \(C \in \mathbb{R}^{P \times H}\), \(D \in \mathbb{R}^{P \times N}\) are learnable params.

8 Aka “state matrix”

Discretization of SSMs

In practice, we’ll always be using discrete data 9 . Thus, we need a way to discretize Equation 1. Turns out the discrete version takes the following form:

9 In the digital era we live in, truly continuous information processing is increasingly rare 🧑‍🦳. Nowadays, sequences we commonly work with include: series of text tokens, slices of images, or samples of audio waveforms.

\[ \begin{split} & h_t = \overline{A} h_{t-1} + \overline{B} x_t \\ & y_t = \overline{C} h_t + \overline{D} x_t \end{split} \tag{2}\]

Where \(\overline{A}, \overline{B}, \overline{C}, \overline{D}\) have closed formulas in terms of \(A, B, C, D\), and a new (learnable) step-size parameter \(\Delta \in \mathbb{R}_{>0}\)10. These closed formulas depend on the discretization method used, which is one of the key differentiators between distinct SSM architectures.

10 Defines the granularity of the approximation.

The S4 blog 3 presents many great interpretations of discretization. The one that resonated with me the most is to numerically approximate the derivative term in Equation 1. Depending on the approximation method used, we get different formulas for \(\overline{A}, \overline{B}, \overline{C}, \overline{D}\) 11. Let’s see some examples, shall we? Yes, we shall:

Euler’s method

For instance, we can approximate the ODE numerically using Euler’s method:

\[ f^{\prime} (x) \simeq \frac{f(x + \Delta) - f(x)}{\Delta} \]

Using our notation:

\[ h^{\prime}_{k-1} = \frac{h_k - h_{k-1}}{\Delta} \]

Then:

\[ \begin{split} h_k =& h_{k-1} + \Delta h^{\prime}_{k-1}\\ =& h_{k-1} + \Delta \left( A h_{k-1} + B x_k \right)\\ =& \left(I + \Delta A \right) h_{k-1} + \left( \Delta B \right) x_k\\ =&: \overline{A} h_{k-1} + \overline{B} x_k \end{split} \]

Thus, if using this approximation:

\[ \begin{split} &\overline{A} := I + \Delta A \\ &\overline{B} := \Delta B \\ &\overline{C} := C \\ &\overline{D} := D \\ \end{split} \]

Nice! Now, given an \(\text{SSM} = \{A, B, C, D\}\) we can find its discretized form \(\text{SSM}_{\Delta} = \{ \overline{A}, \overline{B}, \overline{C}, \overline{D} \} = \{I + \Delta A, \Delta B, C, D \}\)
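If you prefer code to formulas, here is a minimal NumPy sketch of the Euler discretization. The function name `discretize_euler` and the random 3-state SSM are made up for illustration; this is not taken from any SSM library.

```python
import numpy as np

def discretize_euler(A, B, C, D, delta):
    """Euler discretization: A_bar = I + delta*A, B_bar = delta*B; C and D are unchanged."""
    A_bar = np.eye(A.shape[0]) + delta * A
    B_bar = delta * B
    return A_bar, B_bar, C, D

# Hypothetical continuous SSM: state dim H=3, input dim N=2, output dim P=1
rng = np.random.default_rng(0)
A = -np.eye(3) + 0.1 * rng.standard_normal((3, 3))
B = rng.standard_normal((3, 2))
C = rng.standard_normal((1, 3))
D = np.zeros((1, 2))

A_bar, B_bar, C_bar, D_bar = discretize_euler(A, B, C, D, delta=0.01)
print(A_bar.shape, B_bar.shape)  # (3, 3) (3, 2)
```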

Trapezoid rule

Similarly (but more accurately), we can approximate the ODE using the trapezoid rule 12. This is the approximation method used in S4 (the paper preceding Mamba).

In this case, one gets that:

\[ \begin{split} &\overline{A} := \left(I - \frac{\Delta}{2} A \right)^{-1} \left(I + \frac{\Delta}{2} A \right)\\ &\overline{B} := \left(I - \frac{\Delta}{2} A \right)^{-1} \Delta B \\ &\overline{C} := C \\ &\overline{D} := D \\ \end{split} \]
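The same idea for the trapezoid (bilinear) rule, as a sketch with a small hypothetical stable \(A\). It solves linear systems instead of explicitly inverting \(I - \frac{\Delta}{2} A\), which is the usual numerically safer choice; `discretize_bilinear` is just an illustrative name:

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear rule: A_bar = (I - d/2 A)^-1 (I + d/2 A), B_bar = (I - d/2 A)^-1 d B."""
    I = np.eye(A.shape[0])
    left = I - (delta / 2.0) * A
    # Solve left @ X = RHS rather than forming inv(left)
    A_bar = np.linalg.solve(left, I + (delta / 2.0) * A)
    B_bar = np.linalg.solve(left, delta * B)
    return A_bar, B_bar

A = np.array([[-1.0, 0.0], [0.5, -2.0]])  # hypothetical stable A (eigenvalues -1, -2)
B = np.array([[1.0], [0.0]])
A_bar, B_bar = discretize_bilinear(A, B, delta=0.1)
print(np.abs(np.linalg.eigvals(A_bar)))   # both magnitudes below 1
```

A nice property you can check with this snippet: the bilinear transform maps continuous eigenvalues with negative real part into the unit disk, so a stable continuous SSM stays stable after discretization. Euler gives no such guarantee when \(\Delta\) is large.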

Interpolation

Interestingly, both numerical approximations are in fact approximations of the exact discretization obtained by assuming the input is held constant within each step (zero-order hold):

\[ \begin{split} &\overline{A} := e^{\Delta A} \\ &\overline{B} := A^{-1} \left( e^{\Delta A} - I \right) B \\ &\overline{C} := C \\ &\overline{D} := D \\ \end{split} \]

  • Euler method yields an approximation of: \(e^x \simeq 1 + x\)
  • The trapezoid rule ends up with a first-order Padé approximation: \(e^x \simeq \frac{1 + \frac{x}{2}}{1 - \frac{x}{2}}\).

Why not use this formulation then? One needs to balance computational cost and numerical accuracy. For instance, ZOH (Zero-Order Hold) uses \(\overline{A} := e^{\Delta A}\), and computing the exponential of a general dense matrix takes cubic time in the state size, whereas with the bilinear transformation S4’s structured (diagonal plus low-rank) \(A\) lets \(\overline{A}\) be applied in linear time.
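A quick numerical check of the two bullet points above. To keep it simple I assume a hypothetical diagonal \(A\), so each entry is discretized independently and the matrix exponential is just an elementwise `np.exp`:

```python
import numpy as np

a = np.array([-0.5, -1.0, -4.0])   # hypothetical diagonal entries of a stable A
errors = {}
for delta in (0.01, 0.1, 1.0):
    x = delta * a
    exact = np.exp(x)                        # ZOH: e^{delta * a}
    euler = 1 + x                            # Euler: first-order Taylor expansion
    bilinear = (1 + x / 2) / (1 - x / 2)     # trapezoid rule: first-order Padé approximant
    errors[delta] = (np.abs(euler - exact).max(), np.abs(bilinear - exact).max())
    print(f"delta={delta}: Euler error={errors[delta][0]:.1e}, bilinear error={errors[delta][1]:.1e}")
```

The bilinear error is smaller at every step size. Also note that with `delta=1` the Euler factor for `a=-4` is `1 + x = -3`, whose magnitude exceeds 1: the discretized recurrence would diverge even though the continuous system is stable.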

12 This leads to the bilinear transform (also known as Tustin’s method), often used in control theory to convert from continuous time to discrete time and vice versa.

11 Again, the way \(\overline{A}, \overline{B}, \overline{C}, \overline{D}\) are expressed in terms of \(A, B, C, D\) is the key spice 🌶️ in your discrete SSM implementation.

SSM as a Recurrent Model

So… a hidden state \(h_t\) which depends on the previous hidden state \(h_{t-1}\) and the input \(x_t\)… This sounds familiar… It is actually just a linear RNN!

Caution 2

For simplicity, from now onwards, I’ll ignore matrix \(D\) and \(\overline{D}\). It can be simply seen as a skip connection between input and output (not very interesting stuff).

SSM as a Convolution

Here, the key difference wrt other RNN architectures is that all dependencies are linear 13 . This allows us to unravel the multi-step iteration into a one-step convolution14, which yields a massive computation speedup. In general, we obtain this closed form:

13 I can’t stress enough how important this is.

14 If the whole input sequence \(x\) is available (it is usually available at train time but not at inference time)

\[ y_k = \overline{C} \overline{A}^k \overline{B} x_0 + \overline{C} \overline{A}^{k-1} \overline{B} x_1 + \ldots + \overline{C} \overline{A} \overline{B} x_{k-1} + \overline{C} \overline{B} x_k + \overline{D} x_k \]

If we take \(h_{-1} = 0\), we have:

\[ \begin{split} h_0 &= \overline{B} x_0 \\ h_1 &= \overline{A} h_0 + \overline{B} x_1 = \overline{A} \overline{B} x_0 + \overline{B} x_1 \\ h_2 &= \overline{A} h_1 + \overline{B} x_2 = \overline{A}^2 \overline{B} x_0 + \overline{A} \overline{B} x_1 + \overline{B} x_2 \\ &\ldots \end{split} \]

Then since \(y_t = \overline{C} h_t + \overline{D} x_t\) the output can be expressed as:

\[ \begin{split} y_0 &= \overline{C} h_0 + \overline{D} x_0 = \overline{C} \overline{B} x_0 + \overline{D} x_0 \\ y_1 &= \overline{C} h_1 + \overline{D} x_1 = \overline{C} \overline{A} \overline{B} x_0 + \overline{C} \overline{B} x_1 + \overline{D} x_1 \\ y_2 &= \overline{C} h_2 + \overline{D} x_2 = \overline{C} \overline{A}^2 \overline{B} x_0 + \overline{C} \overline{A} \overline{B} x_1 + \overline{C} \overline{B} x_2 + \overline{D} x_2 \\ &\ldots \end{split} \]

From here it is easy to generalize to the \(y_k\) presented above.
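A small sketch checking numerically that the recurrent and convolutional views agree, for a hypothetical SISO SSM with random parameters (\(D\) ignored, as announced). The kernel is stored in the usual convolution order, so its first entry \(\overline{C}\overline{B}\) multiplies the current input and `np.convolve` can be applied directly:

```python
import numpy as np

rng = np.random.default_rng(0)
H, L = 4, 16                                    # hypothetical state size and sequence length
A_bar = np.diag(rng.uniform(0.5, 0.95, H))      # contractive (stable) state matrix
B_bar = rng.standard_normal((H, 1))
C_bar = rng.standard_normal((1, H))
x = rng.standard_normal(L)

# Recurrent view: h_t = A_bar h_{t-1} + B_bar x_t with h_{-1} = 0, y_t = C_bar h_t
h = np.zeros((H, 1))
y_rnn = np.empty(L)
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rnn[t] = (C_bar @ h).item()

# Convolutional view: K[i] = C_bar A_bar^i B_bar, then y_k = sum_i K[i] x[k - i]
K = np.array([(C_bar @ np.linalg.matrix_power(A_bar, i) @ B_bar).item() for i in range(L)])
y_cnn = np.convolve(x, K)[:L]                   # full convolution, keep the causal part

print(np.allclose(y_rnn, y_cnn))  # True
```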

Pro life tip: Whenever you see something like \(\overline{A}^k\) where \(k\) can be big, your alarms should go like this: 🚨

If your matrix \(\overline{A}\) is expansive, i.e.:

\[ \exists \alpha > 1 \quad s.t. \quad \forall h \in \mathbb{R}^H_{\neq 0} \quad \Vert \overline{A} h \Vert > \alpha \Vert h \Vert \]

or, more generally, if any of its eigenvalues satisfies:

\[ | \lambda_i | > 1 \quad \left( eigen(\overline{A}) = \{ \lambda_i \}_{i} \right) \]

you had better forget about it: \(\overline{A}^k\) will explode and everything will diverge. On the contrary, if your matrix is very contractive (all \(|\lambda_i|\) well below 1), \(\overline{A}^k\) will vanish. 15

In our case, we need a contractive matrix, which means that when computing the last elements of the output \(y_k\), the contributions of the first inputs will vanish. However, intuitively, most things seen very long ago should have little influence on the current output (at least for now 😉).
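To see the alarm go off 🚨, here is a tiny sketch with two hypothetical diagonal matrices: one with an eigenvalue slightly above 1, and one with all eigenvalues below 1. It prints the spectral radius and the norm of a few powers:

```python
import numpy as np

def power_norms(A_bar, ks=(1, 10, 50, 100)):
    """Spectral norm of A_bar^k for a few k, to see whether powers explode or vanish."""
    return [np.linalg.norm(np.linalg.matrix_power(A_bar, k), 2) for k in ks]

expansive = np.diag([1.05, 0.9])    # one eigenvalue with |lambda| > 1
contractive = np.diag([0.5, 0.9])   # all |lambda| < 1

for name, M in [("expansive", expansive), ("contractive", contractive)]:
    rho = np.max(np.abs(np.linalg.eigvals(M)))   # spectral radius
    print(name, "rho =", rho, "norms of A^k:", np.round(power_norms(M), 6))
```

Already at \(k=100\), the first matrix has grown by more than two orders of magnitude (\(1.05^{100} \approx 131\)), while the second one has shrunk below \(10^{-4}\) (\(0.9^{100} \approx 2.7 \cdot 10^{-5}\)).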

15 This is the same problem we have with vanishing and exploding gradients in very deep (or recurrent) neural networks. For DNNs, residual skip connections saved the day by allowing gradients to flow through the network without being affected by expansive or contractive matrices multiple times.

We can extract these coefficients into what is called the SSM kernel \(\overline{K}\):

\[ \overline{K} := \left( \overline{C} \overline{A}^i \overline{B} \right)_{i \in [L]} = \left( \overline{C} \overline{B}, \overline{C} \overline{A} \overline{B}, \ldots, \overline{C} \overline{A}^{L-1} \overline{B} \right) \]

So, essentially, we can express the mapping between \(y\) and \(x\) as this convolution 16 :

16 Note that this is a giant filter (as long as the sequence), and its components are matrices of shape \(P \times N\).

\[ y = \overline{K} * x \]

The kernel containing the matrices moves towards the right.

The previous visualization of a 1-D convolution is for interpretation purposes.

While it works, a more efficient implementation (for long sequences) is to use the convolution theorem with Fast Fourier Transform (FFT). I also recommend checking my notes on convolution vs cross-correlation to clarify some common misconceptions.

The idea is to multiply the FFTs of the (zero-padded) input sequence and kernel, and then apply an inverse FFT, more or less like so:

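import numpy as np

def fft_conv(x, K):
    n = x.shape[0] + K.shape[0]  # pad both to this length so circular conv == linear conv
    xd = np.fft.rfft(np.pad(x, (0, K.shape[0])))  # n-point real FFT of the input
    Kd = np.fft.rfft(np.pad(K, (0, x.shape[0])))  # n-point real FFT of the kernel
    out = xd * Kd  # pointwise product in frequency space == convolution in time
    return np.fft.irfft(out, n=n)[: x.shape[0]]  # inverse FFT, keep the first len(x) outputs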

Notice that to use this theorem for non-circular convolutions (as in our case), we need to pad the input sequences with zeros and then truncate the output sequence.

So… RNN or CNN?

Alright, so we saw that we can use these SSMs both as an RNN and as a CNN. But how is that useful?

The let’s-not-think-too-much answer is quite simple, use the:

  • CNN version when you have the whole input sequence available: you usually do at train time. This means that both the forward and backward passes become highly parallelizable one-step operations, in contrast to training RNNs, which is known to be very inefficient17.

  • RNN version when the whole input sequence is NOT available, mainly when you are using your model in an autoregressive way18 or it is ingesting a stream of data19. You usually use this version at inference time. At each time step you just get \(y_k, h_k\) from \(h_{k-1}\) and \(x_k\) by applying the recursion expressed in Equation 2 (a couple of matrix multiplications).

17 Mainly because you cannot parallelize the processing of a sequence: you need to iterate one element at a time, compute the output, store the intermediate states, and then do backpropagation through time.

18 E.g: performing next-token prediction.

19 E.g: real-time audio transcription.

We want the CNN version to train, and the RNN version to infer.
Important 6: COMPLEXITY ANALYSIS

Notice that these are the same characteristics that make the transformer architecture so powerful20. However, standard transformers require:

  • Quadratic memory at train time (wrt input length, at least in a naive implementation)
  • Linear memory at inference time (the key/value cache grows with the sequence), which is prohibitive for long sequences.

What makes SSMs so attractive is that they work with:

  • Linear memory at train time.
  • Constant memory at inference time! 21
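A back-of-the-envelope sketch of what this difference means in bytes. The model size, fp16 storage and `d_state=16` are hypothetical, and I only count the transformer’s key/value cache versus one SSM hidden state per layer (ignoring weights, activations, and any extra per-layer state a real model keeps):

```python
def kv_cache_bytes(seq_len, n_layers, d_model, bytes_per_value=2):
    # A transformer caches keys and values: 2 tensors of shape (seq_len, d_model) per layer
    return 2 * n_layers * seq_len * d_model * bytes_per_value

def ssm_state_bytes(n_layers, d_model, d_state, bytes_per_value=2):
    # An SSM keeps one hidden state of shape (d_model, d_state) per layer, whatever seq_len is
    return n_layers * d_model * d_state * bytes_per_value

cfg = dict(n_layers=24, d_model=1024)   # hypothetical model size, fp16 values
for seq_len in (1_000, 100_000):
    kv = kv_cache_bytes(seq_len, **cfg) / 2**20
    ssm = ssm_state_bytes(d_state=16, **cfg) / 2**20
    print(f"seq_len={seq_len}: KV cache {kv:.1f} MiB vs SSM state {ssm:.2f} MiB")
```

The KV cache grows linearly with `seq_len` (100x longer sequence, 100x more memory), while the SSM state does not depend on it at all.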

21 This is the most important thing of all this. More on this in this paper and in my post on sequence modelling complexity.

20 Besides explicitly implementing attention layers

SSM follow-ups

This is all very nice, but (quite unsurprisingly) its vanilla implementation performs very poorly in practice: it struggles to model non-linear dynamics22. In this section we’ll go through the main ideas from some papers which introduce key improvements23: HiPPO, S4, DSS, S6 (Mamba).

22 Fancy way of saying that the change in the output is not proportional to a change in the input

23 They are sorted chronologically (as they build on top of each other)

HiPPO

Ok, let’s forget for a moment about SSMs (trust me on this one, I know what I’m doing). Let’s try to solve the following problem instead: How can I compress everything I’ve seen so far into a fixed representation? 24

24 I find the intersection of data compression and machine learning interesting and I explored it a bit in another post in which I ramble around information theory.

25 High-order Polynomial Projection Operator: HiPPO: Recurrent Memory with Optimal Polynomial Projections paper (August 2020)

HiPPO25 introduces a general framework for the online compression of continuous signals (and discrete time series) by projecting them onto polynomial bases. This framework can easily be integrated into RNNs (such as SSMs). In addition, it generalizes previous models such as LSTMs and GRUs, improving on them and achieving SOTA on permuted MNIST (a popular long-range sequence benchmark).

For now, all you need to know is that initializing the SSM matrix \(A\) with “the HiPPO matrix” makes the model a more effective learner 26:

26 It is observed that simply modifying an SSM from a random matrix \(A\) to HiPPO improved its performance on the sequential MNIST classification benchmark from 60% to 98%.

\[ \text{HiPPO matrix:} \space \space \space A_{nk} = - \begin{cases} \sqrt{(2n+1)(2k+1)} & \text{if } n > k \\ n + 1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \]

In TLDR terms: this matrix makes the model’s hidden state \(h(t)\) a representation of the whole input \(x(t)\) seen so far. This means that the hidden state \(h(T) \in \mathbb{R}^H\) contains all the information required to (approximately) reconstruct \(x(t) \quad \forall t \in [0, T]\). Thus, we are effectively compressing the whole input sequence into a single fixed-size vector.

Problem

HiPPO formulates the problem of online function approximation as follows. Given a continuous function:

\[ f(t) : \mathbb{R} \rightarrow \mathbb{R} \]

The goal is to find a fixed representation (vector):

\[ c(t) \in \mathbb{R}^N \]

Such that \(c(T)\) captures the history of \(f(t) \space \space \forall t \in [0, T]\)

But how? The idea is to project the seen function \(f\) onto an \(N\)-dimensional function space spanned by some basis functions, and store the coefficients \(c\) of the projection. For this, we need to specify a measure and a basis.

Measure

A measure (or weight / error function) \(\mu^{(t)}\) tells how much we “care” about every time in the past. In the paper, they focus on these two cases:

  • Translated Legendre Measure (LegT): We only “care” about recent events, so we focus on a sliding window of length \(\theta\) from now past-wise (if that word exists).

  • Scaled Legendre Measure (LegS): We “care” equally about everything we’ve seen so far.

The black line is \(f(t)\). The red box represents the measure’s weight at \(t_0\). The blue box represents the measure’s weight at \(t_1\)

Basis

We need to choose a basis of the function space on which to project \(f\). In that paper, they focus on orthogonal polynomials, in particular Legendre polynomials. However, one can project to any basis such as Chebyshev basis, or to Fourier basis (recovering the Fourier Recurrent Unit).

As a refresher, Legendre polynomials are normalized so that \(P_n (1) = 1 \; \forall n\), and satisfy:

\[ \int_{-1}^{1} P_{m}(x)P_{n}(x) \, dx = \begin{cases} 0 & \text{if } n \neq m \\ \frac{2}{2n+1} & \text{if } n = m \end{cases} \]

Visualization of the first six Legendre polynomials (they can easily be obtained by construction from \(P_0 (x) = 1\)).

One can calculate the coefficients \(a_n\) of the projection of a function \(f(x) : [-1, 1] \rightarrow \mathbb{R}\) onto Legendre polynomials by solving:

\[ a_n = \frac{2n+1}{2} \int_{-1}^{1} f(x)P_n(x) \, dx \]

Then, we have that:

\[ f(x) \approx \sum_{n=0}^{N-1} a_n P_n(x) \]
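A minimal sketch of this (static) projection with NumPy’s Legendre module, computing the integral with Gauss-Legendre quadrature. The test function and `N=12` are arbitrary choices for illustration:

```python
import numpy as np
from numpy.polynomial import legendre as leg

def legendre_coeffs(f, N, n_quad=64):
    """a_n = (2n+1)/2 * integral_{-1}^{1} f(x) P_n(x) dx, via Gauss-Legendre quadrature."""
    x, w = leg.leggauss(n_quad)              # quadrature nodes and weights on [-1, 1]
    a = np.empty(N)
    for n in range(N):
        Pn = leg.legval(x, np.eye(N)[n])     # P_n evaluated at the nodes
        a[n] = (2 * n + 1) / 2 * np.sum(w * f(x) * Pn)
    return a

def f(x):
    return np.exp(x) * np.sin(3 * x)         # hypothetical smooth function on [-1, 1]

a = legendre_coeffs(f, N=12)
xs = np.linspace(-1, 1, 201)
approx = leg.legval(xs, a)                   # sum_n a_n P_n(x)
print("max abs error:", np.abs(approx - f(xs)).max())
```

As a sanity check, projecting \(f(x) = x^2\) gives \(a = (1/3, 0, 2/3, 0, \ldots)\), since \(x^2 = \frac{1}{3}P_0(x) + \frac{2}{3}P_2(x)\).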

However, in this case, we are concerned with the approximation in a dynamic context, where the coefficients \(c(t)\) change over time (as we “experience” \(f(t)\)). Check the paper for details, but it turns out that there exists a closed-form solution to this approximation problem, given by this linear differential equation:

\[ \dot{c}(t) = A(t)c(t) + B(t)f(t) \]

Summary of the HiPPO framework under the LegS measure: The projection coefficients \(c\) evolve through time according to a linear dynamical system. \(g(t)\) represents the reconstructed function (linear combination of basis elements according to \(c(t)\)). The colored boxes represent the measures used (weight given through time).

Essentially, the rate of change of the coefficients depends linearly on their current value and on \(f\). \(A(t)\) and \(B(t)\) depend on the chosen measure, and by following these dynamics (aka solving the ODE) one can find the coefficients \(c(t)\) that optimally approximate \(f(t)\) according to that measure.

For instance, in the measures presented earlier, we obtain the following expressions:

  • HiPPO LegT Operator:

\[ \frac{d}{dt} c(t) = -Ac(t) + Bf(t) \]

\[ A_{nk} = \frac{1}{\theta} \begin{cases} (-1)^{n-k}(2n+1) & \text{if } n \geq k, \\ 2n+1 & \text{if } n < k, \end{cases} \quad B_n = \frac{1}{\theta} (-1)^n (2n+1) \]

  • HiPPO LegS Operator:

\[ \frac{d}{dt} c(t) = -\frac{1}{t}Ac(t) + \frac{1}{t}Bf(t) \]

\[ A_{nk} = \begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k, \\ n+1 & \text{if } n = k, \\ 0 & \text{if } n < k, \end{cases} \quad B_n = (2n+1)^{\frac{1}{2}} \]

Notice this is (up to the sign convention of the ODE above) the state matrix we “pick” as a starting point for our model, somehow “helping” it learn the projection of the seen part of \(f(t)\) onto Legendre polynomials with a scaled measure. 🥳
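To see HiPPO-LegS in action, here is a rough sketch that runs the LegS ODE online over a sampled signal and then reconstructs the signal from the final coefficients. My assumptions: the time steps are the integers \(t = 1, \ldots, T\), the ODE is discretized with the bilinear rule from earlier, and the reconstruction uses the scaled Legendre basis \(g_n(x) = \sqrt{2n+1}\, P_n(2x/t - 1)\), which is orthonormal under the LegS measure. This is not the paper’s reference code:

```python
import numpy as np
from numpy.polynomial import legendre as leg

def hippo_legs(N):
    """LegS A and B as written above (the ODE uses -A/t and B/t)."""
    n = np.arange(N)
    A = np.tril(np.sqrt(np.outer(2 * n + 1, 2 * n + 1)), k=-1) + np.diag(n + 1)
    B = np.sqrt(2 * n + 1)
    return A, B

def legs_online(f_samples, N):
    """Bilinear step for dc/dt = -A c / t + B f / t, with t = k and step size 1."""
    A, B = hippo_legs(N)
    I = np.eye(N)
    c = np.zeros(N)
    for k, f_k in enumerate(f_samples, start=1):
        rhs = (I - A / (2 * k)) @ c + (B / k) * f_k
        c = np.linalg.solve(I + A / (2 * k), rhs)
    return c

def reconstruct(c, xs, t):
    """g(x) = sum_n c_n sqrt(2n+1) P_n(2x/t - 1), for x in [0, t]."""
    n = np.arange(len(c))
    return leg.legval(2 * xs / t - 1, c * np.sqrt(2 * n + 1))

T = 1000
ts = np.arange(1, T + 1)
f = np.sin(ts / 80.0)             # hypothetical input signal
c = legs_online(f, N=32)          # 32 numbers summarizing all 1000 samples
err = np.abs(reconstruct(c, ts, T) - f)
print("mean reconstruction error:", err.mean())
```

One property worth checking: for a constant input \(f \equiv 1\) the exact LegS projection is \(c = (1, 0, \ldots, 0)\) (because the first column of \(A\) equals \(B\)), and the discrete update converges to it as more samples arrive.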

If curious, this is how the matrix looks for \(N=8\).

# Please don't judge the code, it's all ChatGPT
# Note: this plots the LegS entries without the minus sign; S4's state matrix A is the negative of this
import numpy as np
import matplotlib.pyplot as plt


def generate_matrix(N):
    A = np.zeros((N, N))  # Initialize a N x N matrix with zeros

    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt((2*n + 1)*(2*k + 1))
            elif n == k:
                A[n, k] = n + 1
            else:
                A[n, k] = 0  # This line is technically not needed as the matrix is initialized with zeros

    return A

# For example, generate an 8x8 matrix
N = 8
matrix = generate_matrix(N)

plt.figure(figsize=(9, 7))
plt.imshow(matrix, cmap='viridis', aspect='equal')
plt.colorbar()
plt.title("HiPPO matrix")

# Annotate the heatmap with matrix values
for i in range(N):
    for j in range(N):
        text = plt.text(j, i, f"{matrix[i, j]:.2f}",
                       ha="center", va="center", color="w")

plt.show()

S4

S427 takes a basic SSM model and adds two key contributions:

27 Structured State Space for Sequence Modelling paper (Oct 2021). Note: Structured, because it imposes structure on the \(A\) matrix.

  • It uses the HiPPO matrix to initialize the “state matrix” \(A\), which is then kept in a structured (DPLR) form28.

  • It implements a very efficient way to compute the convolution kernel. In summary: instead of computing the powers \(\overline{A}^k\) explicitly, they evaluate a generating function of the kernel at the roots of unity and recover \(\overline{K}\) with an inverse FFT.

28 In the S4 paper, the terms of this DPLR form (\(\Lambda, P, Q\)) are trained together with \(B, C\), on top of \(\Delta\) and \(D\).

Caution 8: This is SISO!

The SSM layer implementation only considers 1-dimensional input-output sequences (ie \(x_t, y_t \in \mathbb{R}^1\)). This is known as SISO (Single-Input Single-Output). Notice that in this case \(\overline{K} \in \mathbb{R}^{1 \times L}\) is a vector.

To handle multivariate inputs and outputs, they use one such layer per dimension29, coupled together with nonlinear mixing layers (to break the independence assumption).

In the S530 paper, they generalize the S4 layer to be multivariate: not SISO anymore but MIMO (Multi-Input Multi-Output).

30 Simplified State Space Layers for Sequence Modeling (March 2023)

29 Resulting in one SSM per dimension, each with different parameters.

To make it apparent that the dimensions of input and output pairs are independent, I represent them as “batched”. In this example: input_dim = output_dim = 3, and hidden_dim = 4. Thus, we get 3 kernels of dimension 1xL, which we apply in parallel.

Remember that 1-dimensional SSMs are defined by the matrices \(\left(A, B, C \right)\) where \(A \in \mathbb{R}^{N \times N}\), \(B \in \mathbb{R}^{N \times 1}\), \(C \in \mathbb{R}^{1 \times N}\).
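The “batched” picture can be sketched in a few lines: one independent SISO SSM (and thus one \(1 \times L\) kernel) per channel, each convolved only with its own channel. The sizes match the figure; the parameter values are random and hypothetical, standing in for already-discretized \(\overline{A}, \overline{B}, \overline{C}\):

```python
import numpy as np

def siso_kernel(A_bar, B_bar, C_bar, L):
    """K[i] = C_bar @ A_bar^i @ B_bar for one SISO SSM (B_bar, C_bar are vectors of length H)."""
    K = np.empty(L)
    v = B_bar.copy()
    for i in range(L):
        K[i] = C_bar @ v      # C_bar A_bar^i B_bar
        v = A_bar @ v         # advance to A_bar^{i+1} B_bar
    return K

rng = np.random.default_rng(0)
n_channels, H, L = 3, 4, 32   # input_dim = output_dim = 3, hidden_dim = 4
A_bars = 0.3 * rng.standard_normal((n_channels, H, H))   # one independent set of
B_bars = rng.standard_normal((n_channels, H))             # parameters per channel
C_bars = rng.standard_normal((n_channels, H))
x = rng.standard_normal((L, n_channels))

# One 1 x L kernel per channel, each applied only to its own channel
K = np.stack([siso_kernel(A_bars[d], B_bars[d], C_bars[d], L) for d in range(n_channels)])
y = np.stack([np.convolve(x[:, d], K[d])[:L] for d in range(n_channels)], axis=1)

# Perturbing channel 1 leaves the outputs of channels 0 and 2 untouched
x_pert = x.copy()
x_pert[:, 1] += 1.0
y_pert = np.stack([np.convolve(x_pert[:, d], K[d])[:L] for d in range(n_channels)], axis=1)
print(np.allclose(y_pert[:, [0, 2]], y[:, [0, 2]]))  # True
```

This independence is exactly why the nonlinear mixing layers between SSM layers are needed.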

The following is an embarrassing over-simplification31 of the logic they follow to optimize the computation of the kernel:

  1. They notice there exists a special case of SSM matrix structure from which computing a truncated generating function of the kernel \(\overline{K}\) becomes very fast 32: In particular, if the SSM has a Diagonal Plus Low-Rank (DPLR) decomposition in complex space. This means that the SSM needs to have the following structure:

\[ \left(\Lambda - P Q^*, B, C \right) \]

  • Where:
    • \(\Lambda \in \mathbb{C}^{N\times N}\) is diagonal matrix
    • \(P, Q \in \mathbb{C}^{N \times 1}\) form the low-rank correction33.
    • \(B, C \in \mathbb{C}^{N \times 1}\)
    • \(Q^*\) is the conjugate transpose of matrix \(Q\)
  2. They notice that HiPPO’s matrix isn’t DPLR, but Normal Plus Low-Rank (NPLR).
  3. Normal matrices are unitarily diagonalizable, so an NPLR matrix can be rewritten as:

\[ \begin{split} A &= V \Lambda V^* - P Q^T \\ &= V \left( \Lambda - (V^* P) (V^* Q)^* \right) V^* \end{split} \]

  • Where:
    • \(V \in \mathbb{C}^{N\times N}\) is unitary
    • \(\Lambda \in \mathbb{C}^{N\times N}\) is diagonal
    • \(P, Q \in \mathbb{R}^{N \times r}\) form the low-rank correction.
  4. They notice that NPLR matrices are thus equivalent to DPLR matrices from the perspective of SSM models (the unitary change of basis \(V\) can be absorbed into \(B\) and \(C\)).

  5. This means we can work with the DPLR form in complex space and evaluate the generating function of the kernel efficiently, without ever computing the powers \(\overline{A}^k\) explicitly 🥳
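The NPLR claim is easy to check numerically for the HiPPO-LegS matrix. This sketch assumes S4’s sign convention (the state matrix is minus the LegS matrix) and the rank-1 term \(P_n = \sqrt{n + 1/2}\) with \(Q = P\), which is, as far as I know, the choice S4 makes for LegS. Adding \(P P^T\) gives a normal matrix, namely \(-\frac{1}{2} I\) plus a skew-symmetric matrix:

```python
import numpy as np

N = 8
n = np.arange(N)
# S4's state matrix: minus the HiPPO-LegS matrix (lower triangular, diagonal n + 1)
A = -(np.tril(np.sqrt(np.outer(2 * n + 1, 2 * n + 1)), k=-1) + np.diag(n + 1))
P = np.sqrt(n + 0.5)            # low-rank term (rank 1), with Q = P
M = A + np.outer(P, P)          # normal part, so that A = M - P P^T

S = M + 0.5 * np.eye(N)
print(np.allclose(S, -S.T))                          # True: M = -I/2 + skew-symmetric
print(np.allclose(M @ M.T, M.T @ M))                 # True: M is normal
print(np.allclose(A @ A.T, A.T @ A))                 # False: A itself is not normal
print(np.allclose(np.linalg.eigvals(M).real, -0.5))  # True: all eigenvalues have Re = -1/2
```

Being normal, \(M\) can be unitarily diagonalized as \(V \Lambda V^*\), which is exactly the step that turns NPLR into DPLR.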

33 Actually in this case the rank is 1 because they are vectors.

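32 Instead of computing the kernel entry by entry (which requires the powers \(\overline{A}^k\)), you compute its truncated generating function, making use of the z-transform, and recover the kernel from it.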

31 I will not open this melon 🍉 because it is very math heavy, but you can find an explanation here.

  • Standard MNIST classification (Treating each image as a sequence of 784 pixels): As reported: an S4 model of \(H=256\) and \(4\)-layers gets to around 99% accuracy right away.

  • MNIST next-pixel prediction. Like next-token prediction in LMs but with pixel values: given an initial string (prompt), the model should complete the sequence. A model with \(H=512\) and \(6\)-layers (\(\sim 4\) million params) achieved SOTA according to PapersWithCode. Here is an example of a prediction (sample) and the ground truth (the model was fed the white part, the first 300 pixels, and inferred the rest):

  • QuickDraw next-pixel prediction. Same as before but more images representing more things than just digits: A model with \(H=256\) and \(4\)-layers generated “relatively” coherent completions (prompt was first 500 pixels)

  • Spoken digits classification: Same as MNIST classification but the input is an audio wave of people reading digits from 0 to 9: The model achieved 97% accuracy from the raw soundwave.

  • Spoken digits generation: The model is prompted with the beginning of the audio wave and has to continue it: The cherry-picked examples sound very good (even the tone of voice is matched: Check some examples here) \(H=512\). The dataset has around 3k examples of around 6400 steps, at 8kHz sampling rate discretized into 256 classes with \(\mu\)-law encoding:

orange=prompted part, blue=generated (or real) part.
  • Apparently it also performs well at movie-clip classification: ViS4mer paper. [PERSONAL NOTE]: I like the idea of combining the power of a transformer for fixed-size type of data (frame) with the power of S4 of analyzing time series (sequence of frames).

ViS4mer architecture

DSS

DSS34 shows that one can match S4’s level of performance by using a diagonal state matrix (i.e. dropping the low-rank correction). This massively simplifies the computation and makes the implementation much more straightforward.

34 Diagonal State Spaces paper (March 2022)

35 This makes sense if you have read the explanation of S4, otherwise don’t worry too much about it.

Again, the model is very sensitive to parameter initialization. For it, they use the eigenvalues of the normal part of the NPLR form of the HiPPO matrix (selecting the ones that ensure a stable system).35
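A diagonal state matrix is what makes the kernel cheap: with \(\overline{A} = \mathrm{diag}(\overline{\lambda}_1, \ldots, \overline{\lambda}_H)\), each kernel entry is \(\overline{K}_\ell = \sum_n C_n \overline{B}_n \overline{\lambda}_n^\ell\), i.e. a Vandermonde matrix times a vector. Below is a sketch under simplifying assumptions: ZOH discretization, a hypothetical initialization \(\lambda_n = -\frac{1}{2} + i \pi n\) (not DSS’s exact recipe), and keeping only the real part of the result:

```python
import numpy as np

def diagonal_ssm_kernel(Lam, B, C, delta, L):
    """K[l] = Re(sum_n C_n * B_bar_n * A_bar_n**l) for a diagonal complex state matrix (ZOH)."""
    A_bar = np.exp(delta * Lam)                        # ZOH: exp(delta * Lambda), elementwise
    B_bar = (A_bar - 1.0) / Lam * B                    # ZOH input matrix, elementwise
    vander = A_bar[None, :] ** np.arange(L)[:, None]   # (L, H) Vandermonde matrix
    return (vander @ (C * B_bar)).real                 # (L,)

H, L = 16, 64
Lam = -0.5 + 1j * np.pi * np.arange(H)   # hypothetical diagonal initialization
rng = np.random.default_rng(0)
B = np.ones(H, dtype=complex)
C = rng.standard_normal(H) + 1j * rng.standard_normal(H)
K = diagonal_ssm_kernel(Lam, B, C, delta=0.01, L=L)

# Same kernel via an explicit loop over l (slower, but obviously correct)
A_bar = np.exp(0.01 * Lam)
B_bar = (A_bar - 1.0) / Lam * B
K_loop = np.array([np.sum(C * B_bar * A_bar**l).real for l in range(L)])
print(np.allclose(K, K_loop))  # True
```

No low-rank correction, no Woodbury identity: computing the kernel costs \(O(HL)\) with a single matrix-vector product.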

DSS convolution. Same as S4, but using diagonal \(\overline{A}\) matrices.

S6 (Mamba)

Enough chit-chat! Let’s address the 🐘 in the room: our model dynamics36 are constant. Look again at Equation 1 and Equation 2: \(A, B, C, \Delta\) (and thus \(\overline{A}, \overline{B}\)) remain fixed throughout the sequence! This property is known as Linear Time Invariance (LTI), and it is highly limiting in terms of modelling power. However, if we break LTI (by making these matrices vary), we lose the possibility to process the input time series in one step (we lose the CNN aspect of the SSM, as the “kernel” is not fixed anymore 📉).

36 How things change through time.

37 Mamba: Linear-Time Sequence Modeling with Selective State Spaces paper (Jan 2024)

Mamba37 does exactly that, adding two contributions: Selectivity (make \(B, C, \Delta\) matrices depend on input \(x_t\)) and, as its consequence, the “scan operation” (convolution substitute).

Combined together, they become the selective scan algorithm38.

38 S6 because this is S4 + Selective Scan

Make of this whatever you want:

SSSSSS like the snakes

Selectivity

Mamba breaks the LTI property of SSMs by allowing their parameters \(B, C, \Delta\)39 to depend on the current input. We now have:

39 They keep \(A\) input-independent (it is still a learned parameter) and diagonal, as in DSS.

\[ \begin{split} & h_t = \overline{A}_t h_{t-1} + \overline{B}_t x_t \\ & y_t = \overline{C}_t h_t \end{split} \tag{3}\]

Where the matrices \(\overline{A}, \overline{B}\)40 of the discretized version are computed the same as explained before. But now, we have that:

40 Remember that \(\overline{C} = C\) (seen in Tip 1).

\[ \begin{split} &B_t = \text{LINEAR}_B (x_t)\\ &C_t = \text{LINEAR}_C (x_t)\\ &\Delta_t = \text{LINEAR}_{\Delta} (x_t)\\ \end{split} \]
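A naive, sequential sketch of Equation 3 with input-dependent \(B_t, C_t, \Delta_t\). Simplifications and assumptions: a random matrix stands in for each LINEAR, \(\Delta_t\) goes through a softplus to stay positive, \(\overline{A}_t = e^{\Delta_t A}\) (ZOH) while \(\overline{B}_t\) uses the simpler \(\Delta_t B_t\), and \(A\) is a fixed negative diagonal per channel. It is meant to show the data flow, not to reproduce the official implementation:

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)   # log(1 + e^z), computed without overflow

def selective_ssm(x, A, W_B, W_C, W_delta):
    """Sequential sketch of a selective SSM (Equation 3) for an input x of shape (L, D).

    A: (D, N) negative diagonal state matrix, one row per channel (not input-dependent).
    W_B, W_C: (D, N) and W_delta: (D, D) are hypothetical weights for the LINEAR maps.
    """
    L, D = x.shape
    h = np.zeros_like(A)                       # (D, N): one N-dim state per channel
    y = np.empty((L, D))
    for t in range(L):
        B_t = x[t] @ W_B                       # (N,)  B depends on the input
        C_t = x[t] @ W_C                       # (N,)  C depends on the input
        delta_t = softplus(x[t] @ W_delta)     # (D,)  step size depends on the input, > 0
        A_bar = np.exp(delta_t[:, None] * A)   # (D, N) ZOH discretization of A
        B_bar = delta_t[:, None] * B_t         # (D, N) simplified discretization of B
        h = A_bar * h + B_bar * x[t][:, None]  # elementwise because A is diagonal
        y[t] = h @ C_t                         # (D,)
    return y

rng = np.random.default_rng(0)
L, D, N = 10, 4, 8
x = rng.standard_normal((L, D))
A = -np.tile(np.arange(1, N + 1, dtype=float), (D, 1))   # hypothetical init: A_n = -n
W_B = rng.standard_normal((D, N))
W_C = rng.standard_normal((D, N))
W_delta = 0.1 * rng.standard_normal((D, D))
y = selective_ssm(x, A, W_B, W_C, W_delta)
print(y.shape)  # (10, 4)
```

Because \(B_t, C_t, \Delta_t\) are recomputed from `x[t]` inside the loop, there is no single kernel to precompute.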

In the paper, they show the differences between S4 and S6 like so:

They refer to this property as “selectivity”: now the model can be more selective about what gets included in the state. Some inputs \(x_t\) might be more relevant than others for the given task41; by making \(B\) dependent on \(x\), we can more easily discard non-relevant stuff.42 This way we ensure that not all inputs are treated equally: VIP ones have a bigger influence.

41 For instance, in audio transcription it might not be very important to remember that 5 minutes ago there was a long pause. (I talk a lot about transcription because it’s currently paying my bills).

42 Having \(\Delta\) dependent on \(x\) has a similar effect: it is like choosing the step size based on how “interesting” an observation is.

Very simple: the convolution operation is designed for a kernel that is always the same. Here, the coefficients of the kernel change at each step depending on the input!

For instance, when processing \(y_1\) we have that the first element of the kernel is \(k_1 = \overline{C}^T_1 \overline{B}_1\):

Whereas on the next convolution step, we now have that \(k_1 = \overline{C}^T_2 \overline{B}_2\) 😠

In conclusion, we cannot use the convolution operation, so we need to go back to the recurrent version of the model.

The scan operation

Alright, let’s forget about the CNN then. Let’s take a step back and look into the RNN. A naive way of implementing the computation (when the complete sequence is provided)43 would be something like this:

43 I’ll focus on the computation of h[t] since getting y[t] from h[t] is a simple matrix multiplication.

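def sequential_scan(A_bar, B_bar, x):
    # A_bar[t], B_bar[t]: per-step (diagonal) parameters; * is elementwise
    T = len(x)
    h = [None] * T
    h[0] = B_bar[0] * x[0]  # with h[-1] = 0, the first state only sees x[0]
    for t in range(1, T):  # start at 1: h[0] is already computed
        h[t] = A_bar[t] * h[t-1] + B_bar[t] * x[t]
    return h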

An (overcomplicated)44 representation of this computation would be this:

44 Please bear with me, it’ll make understanding the scan operation much simpler.

First we compute B_bar[t] * x[t] \(\forall t\). This is a tensor multiplication without previous time-step dependencies, which can easily be parallelized. Then, for the iterative part, at every step t we take the output of the previous step h[t-1], multiply it by A_bar[t] and add B_bar[t] * x[t], which gives us h[t]. Notice that we need 8 steps to compute the output of a sequence of length 8.

This is a very efficient way to compute it in terms of number of operations. However, remember that the hardware (GPUs, TPUs) we use to train this type of model is optimized for high parallelization. And, ultimately, what we care about is the time it takes to complete the computation. Considering this, is there a way to compute this faster? Yes, there is: taking inspiration from the parallel scan operation45.

45 This thing has more names than a Spanish telenovela character. It is referred to as: cumulative sum, prefix sum, inclusive scan, or just scan…

46 It is not the most work-efficient one, so it is not what Albert and Tri do, but I think it is useful to get the main idea behind this.

47 Notice that now h is a matrix of shape (log_2(T) + 1) x T: I add a new dimension to h to store intermediate steps (row 0 holds B_bar * x). In the end, we will use the last row.

There exist different implementations of this operation with varying properties (check Wikipedia). Here I’m going to focus on the simplest one to grasp, just to get the gist of it46. The idea is the following47:

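import numpy as np

def parallel_scan(A_bar, Bx):
    # A_bar[t]: diagonal of A_bar at step t; Bx[t] = B_bar[t] * x[t]; both shaped (T, ...)
    T = len(Bx)
    n_levels = int(np.ceil(np.log2(T))) if T > 1 else 0
    h = np.zeros((n_levels + 1,) + Bx.shape)  # h[i, t]: partial state after level i
    h[0] = Bx                                 # level 0: broadcasted for all t
    A_acc = A_bar.copy()                      # A_acc[t] = A[t] * A[t-1] * ... * A[t - 2^i + 1]
    for i in range(n_levels):
        s = 2 ** i
        # Within a level every t is independent: this is the "DO IN PARALLEL" loop,
        # written here as vectorized NumPy slicing.
        h[i + 1, :s] = h[i, :s]                          # t < 2^i: already complete
        h[i + 1, s:] = h[i, s:] + A_acc[s:] * h[i, :-s]  # add the block ending at t - 2^i
        A_acc[s:] = A_acc[s:] * A_acc[:-s]               # products for the next level
    return h[-1]  # the last row holds every h[t]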

Using the same diagram as before, we can now represent the computations as follows48:

48 Not so stupid after all to represent it the way I did, is it?

The general idea is to compute the sequence in parts and iteratively combine them, which is possible thanks to the associativity of the combine step. Thus, we now only need log_2(T) steps to compute the output! 🥳 (if your machine has at least T processors).
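To make “associative” concrete: each recurrence step \(h \mapsto a h + b\) can be represented by the pair \((a, b)\), and composing two steps gives another such pair. The sketch below (scalar/diagonal case, hypothetical random values) shows that this composition is associative and that a prefix scan over it reproduces the recurrence. `itertools.accumulate` still runs sequentially; associativity is what allows a parallel scan (for instance JAX’s `jax.lax.associative_scan`) to regroup the work into a log-depth tree:

```python
import numpy as np
from itertools import accumulate

def combine(left, right):
    """Compose two steps h -> a*h + b: first `left`, then `right`.

    right(left(h)) = a_r * (a_l * h + b_l) + b_r = (a_r * a_l) * h + (a_r * b_l + b_r)
    """
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)

rng = np.random.default_rng(0)
T = 8
a = rng.uniform(0.5, 1.0, T)      # A_bar_t (scalar / diagonal case)
b = rng.standard_normal(T)        # B_bar_t * x_t

# Reference: the plain recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0
h, h_ref = 0.0, []
for t in range(T):
    h = a[t] * h + b[t]
    h_ref.append(h)

# Prefix "sums" under `combine`: the b-part of the t-th prefix equals h_t
h_scan = [b_t for _, b_t in accumulate(zip(a, b), combine)]
print(np.allclose(h_scan, h_ref))  # True

# Associativity: grouping does not matter, so prefixes can be combined in a tree
p, q, r = (a[0], b[0]), (a[1], b[1]), (a[2], b[2])
print(np.allclose(combine(combine(p, q), r), combine(p, combine(q, r))))  # True
```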

It might be useful to see the “path” the computation follows to get a particular output. I acknowledge there is an interesting mess of lines:

As you can see, it depends on all previous inputs, and the powers of A are correctly assigned. Notice I simplified the problem by assuming a constant \(A\) for a clearer explanation. The idea is the same with an input-dependent \(A_t\).

On top of that, they are very conscious of how to efficiently perform these operations in a GPU49.

49 After all, these guys are the ones who came up with FlashAttention (May 2022)

Hardware-aware implementation of the scan operation. In this case, we have an input_dim = output_dim = 5, and a hidden_dim = 4. The idea is that they only materialize the hidden states in the most efficient level of memory (GPU SRAM). This avoids the slow operation of copying these values from SRAM into the HBM50 (this slow IO transfer speed is one of the main bottlenecks of current GPUs).

50 More about fused CUDA kernels here: Tip 14.

Last remark: When training, they do not store the intermediate states in the forward pass. Instead, they recompute them again in the backward pass, which ends up being more efficient. This idea is also a common trick to reduce the memory footprint at train time51.

51 This has many names, such as gradient checkpointing or rematerialization. If interested, check out: Training Deep Nets with Sublinear Memory Cost, or this blog post

The optimization in Mamba is somewhat in line with FlashAttention (May 2022), proposed by Tri Dao et al. The idea is to replace the PyTorch operations (or whichever deep learning framework is being used) with a custom CUDA kernel52 (aka fused kernel) which combines them and performs them more efficiently:

Common flow of operations done in PyTorch. Image from here.

Flow of operations in a memory-aware fused CUDA kernel. Image from here.

For instance, in the case of FlashAttention, they replace the sequence of operations making up the transformer’s omnipresent scaled dot-product attention with a single fused operation:

\[ \text{SOFTMAX} \left( \frac{Q K^T}{ \sqrt{d_k} } \right) V \]
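The core trick that lets this be computed block by block, without ever materializing the \(L \times L\) score matrix, is an online (running) softmax. Below is a NumPy sketch of that idea only: the real FlashAttention kernel also tiles the queries, keeps the tiles in SRAM and is written in CUDA, none of which this shows. The sizes are hypothetical:

```python
import numpy as np

def naive_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, materializing the full (L, L) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))      # subtract row max for stability
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=16):
    """Same result, visiting K/V one block at a time with a running softmax."""
    L, d_k = Q.shape
    out = np.zeros((L, V.shape[1]))   # unnormalized running output
    m = np.full(L, -np.inf)           # running row-wise max of the scores
    l = np.zeros(L)                   # running row-wise sum of exp(score - m)
    for j in range(0, K.shape[0], block):
        S = Q @ K[j:j + block].T / np.sqrt(d_k)   # (L, block): only this tile's scores
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)                 # rescale what was accumulated so far
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=1)
        out = out * scale[:, None] + P @ V[j:j + block]
        m = m_new
    return out / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 64, 8))   # sequence length 64, head dim 8
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V)))  # True
```

Fusing these steps into one kernel is what avoids writing the intermediate scores back to slow memory.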

This (obviously) results in a faster and more memory-efficient attention method. The approach not only combines the operations but also makes better use of the GPU memory hierarchy. For instance, an A100 GPU has:

  • 40GB (or 80 GB) of HBM (High-Bandwidth-Memory): Large but slow: This is implemented by stacking multiple DRAM (Dynamic Random Access Memory) dies allowing high parallel data transfers.

  • ~0.2MB x 108 processors of SRAM (Static Random Access Memory): Small but fast: The speed advantage of SRAM over DRAM comes from SRAM’s ability to hold a given data bit in a static state (on or off) as long as power is supplied. Moreover, it is more reliable than DRAM: DRAM must refresh its stored data bits many times per second to maintain the integrity of the stored data (making operations slower). However, SRAM has a much higher cost per bit and requires more physical space on the chip. Thus, it is usually reserved for operations where speed and reliability are critical (such as caches in CPUs or GPUs).

52 Fancy way of saying: function that runs on GPU written in CUDA

The mamba block

Similar to a transformer block, they provide a “Mamba block”, which can be stacked to build larger models. It basically adds feature expansion, a CNN and a skip connection around the SSM layer like so:

If you check the official implementation, specifically the Mamba block, you’ll see these important parameters: d_model (input and output dimensions), d_state (SSM hidden state dim), d_conv (conv-1d kernel size), and expand (initial projection expansion factor). I referenced them in the diagram.
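To connect those parameters with the data flow, here is a heavily simplified NumPy sketch of one block’s forward pass: input projection into two branches of width expand * d_model, a causal depthwise conv1d of size d_conv, SiLU, a selective SSM with state size d_state, gating, and the output projection. The parameter names, default values and random initialization are illustrative; the full-rank \(\Delta\) projection is my simplification (I believe the official code uses a low-rank one), and normalization plus the residual connection around the block are left out:

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def softplus(z):
    return np.logaddexp(0.0, z)

def init_params(d_model, d_state=16, d_conv=4, expand=2, seed=0):
    """Random (hypothetical) weights; d_inner = expand * d_model is the expanded width."""
    rng = np.random.default_rng(seed)
    d_inner = expand * d_model

    def w(*shape):
        return 0.1 * rng.standard_normal(shape)

    return {
        "in_proj": w(d_model, 2 * d_inner),   # expansion into main branch + gate branch
        "conv_w": w(d_conv, d_inner),         # depthwise conv: one filter per channel
        "conv_b": np.zeros(d_inner),
        "W_B": w(d_inner, d_state),
        "W_C": w(d_inner, d_state),
        "W_delta": w(d_inner, d_inner),       # simplification: full-rank Delta projection
        "A_log": np.log(np.tile(np.arange(1, d_state + 1, dtype=float), (d_inner, 1))),
        "D": np.ones(d_inner),
        "out_proj": w(d_inner, d_model),
    }

def mamba_block(u, p):
    """Simplified forward pass of one Mamba block: u of shape (L, d_model) -> (L, d_model)."""
    L = u.shape[0]
    d_conv = p["conv_w"].shape[0]
    x, z = np.split(u @ p["in_proj"], 2, axis=1)      # (L, d_inner) each
    # Causal depthwise conv1d: output at t only sees x[t - d_conv + 1 .. t]
    x_pad = np.vstack([np.zeros((d_conv - 1, x.shape[1])), x])
    x = np.stack([(x_pad[t:t + d_conv] * p["conv_w"]).sum(axis=0) for t in range(L)])
    x = silu(x + p["conv_b"])
    # Selective SSM: B, C and Delta are computed from the input
    B, C = x @ p["W_B"], x @ p["W_C"]                 # (L, d_state)
    delta = softplus(x @ p["W_delta"])                # (L, d_inner)
    A = -np.exp(p["A_log"])                           # (d_inner, d_state), negative
    h = np.zeros_like(A)
    y = np.empty_like(x)
    for t in range(L):
        h = np.exp(delta[t][:, None] * A) * h + delta[t][:, None] * B[t] * x[t][:, None]
        y[t] = h @ C[t]
    y = y + p["D"] * x            # skip connection around the SSM
    y = y * silu(z)               # gate with the other branch
    return y @ p["out_proj"]

p = init_params(d_model=8)
u = np.random.default_rng(1).standard_normal((12, 8))
out = mamba_block(u, p)
print(out.shape)  # (12, 8)

# Causality check: changing the last input must not change earlier outputs
u_late = u.copy()
u_late[-1] += 1.0
print(np.allclose(mamba_block(u_late, p)[:-1], out[:-1]))  # True
```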

And that’s pretty much it 🙃

Epilogue

I’d just like to add that I found the line of research by Albert and Tri extremely inspiring. It’s super cool to see such talented people thinking of new approaches to time series problems (instead of micro-optimizing transformers). I like to see the results of combining ideas from so many fields: control theory from SSMs, function projections from 🦛, parallel computing and hardware optimization from 🐍

In a future post I’ll explore the relationship between the original transformer, the linear transformer, Mamba, and maybe some other RNN model. I’ll pay special attention to similarities in architecture and time & memory complexities at train/eval times53.

53 As I was writing this (May 2024), Mamba2 came out and they do a lot of this work haha, so I guess I’ll be summarizing that :)