Have you ever wondered: Wait, what if someone explained Mamba worse? Well, buckle up 💺, because today we’re doing just that! First, we’ll review SSMs1. Then, we’ll check out recent deep learning contributions to them2. And with that, we’ll finally be ready to understand Mamba: the strongest contender to Transformer dominance 😱. Mamba surpasses the performance of transformers twice its size while maintaining sub-quadratic memory requirements!3
1 Short for State Space Models
2 I will mainly focus on Albert Gu and Tri Dao’s line of research.
3 This post assumes you are familiar with how basic RNN, CNN, and Transformer models work.
My thoughts when I first learned about SSMs.
State Space Models (SSMs)
SSMs are traditionally used in control theory to model a dynamic system4 via state variables5. Given an input signal $x(t)$, an SSM maps it to an output signal $y(t)$ through a hidden state $h(t)$:

$$h'(t) = A\,h(t) + B\,x(t) \tag{1}$$

$$y(t) = C\,h(t) + D\,x(t) \tag{2}$$
4 Fancy way of saying that functions are dependent on time (or that they are sequential)
5 SSMs’ core ideas are so omnipresent that they have been re-discovered multiple times, resulting in a terminology mess… Control nerds say state variables for what deep learning geeks call hidden states, which is the same as latent variables in the probabilistic world. Potatoes, potahtoes, potatoisses 🥔
In practice, we’ll always be using discrete data 9. Thus, we need a way to discretize Equation 1. It turns out the discrete version takes the following form:

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t \tag{3}$$

$$y_t = C\,h_t + D\,x_t \tag{4}$$
9 In the digital era we live in, truly continuous information processing is increasingly rare 🧑‍🦳. Nowadays, commonly processed sequences include: series of text tokens, slices of images, or samples of audio waveforms.
Where $\bar{A}$ and $\bar{B}$ have closed formulas in terms of $A$, $B$, and a new (learnable) step-size parameter $\Delta$ 10. These closed formulas depend on the discretization method used, which is one of the key differentiators between distinct SSM architectures.
10 Defines the granularity of the approximation.
Note 1: Discretization rule examples
The S4 blog 3 presents many great discretization interpretations. The one that resonated with me the most is to numerically approximate the derivative term in Equation 1. Depending on the approximation method used, we get different formulas for $\bar{A}$ and $\bar{B}$ 11. Let’s see some examples, shall we? Yes, we shall:
Euler’s method
For instance, we can approximate the ODE numerically using Euler’s method:

$$h(t + \Delta) \approx h(t) + \Delta\, h'(t)$$

Using our notation:

$$h_{k+1} \approx h_k + \Delta\,(A\,h_k + B\,x_k)$$

Then:

$$h_{k+1} \approx (I + \Delta A)\,h_k + \Delta B\,x_k$$

Thus, if using this approximation:

$$\bar{A} = I + \Delta A, \qquad \bar{B} = \Delta B$$

Nice! Now, given an $(A, B, \Delta)$ we can find its discretized form $(\bar{A}, \bar{B})$.
Trapezoid rule
Similarly (but more accurately), we can approximate the ODE using the trapezoid rule12. This is the approximation method used in S4 (the paper preceding Mamba).

In this case, one gets that:

$$\bar{A} = \left(I - \tfrac{\Delta}{2} A\right)^{-1}\left(I + \tfrac{\Delta}{2} A\right), \qquad \bar{B} = \left(I - \tfrac{\Delta}{2} A\right)^{-1} \Delta B$$
Interpolation
Interestingly, these numerical approximation results are in fact function approximations of the exact (Zero-Order Hold) solution:

$$\bar{A} = e^{\Delta A}, \qquad \bar{B} = (\Delta A)^{-1}\left(e^{\Delta A} - I\right)\Delta B$$

Why not use this formulation then? Notice that one needs to balance computational cost and numeric accuracy. For instance, ZOH (Zero-Order Hold) uses the matrix exponential $e^{\Delta A}$, which takes cubic time, whereas the bilinear transformation can be computed in linear time.
12 This leads into a bilinear transformation often used in control theory to convert from continuous-time to discrete-time and vice-versa.
11 Again, different formulations of $(\bar{A}, \bar{B})$ in terms of $(A, B, \Delta)$ are the key spice 🌶️ in your discrete SSM implementation.
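To make these rules concrete, here is a toy sketch of mine (a random stable matrix and a made-up step size, nothing from the papers) implementing the three formulas above with plain NumPy/SciPy:

```python
import numpy as np
from scipy.linalg import expm, inv

def discretize(A, B, delta, method="zoh"):
    """Toy discretization of a continuous SSM (A, B) with step size delta."""
    I = np.eye(A.shape[0])
    if method == "euler":       # first-order approximation
        A_bar = I + delta * A
        B_bar = delta * B
    elif method == "bilinear":  # trapezoid rule (used in S4)
        A_bar = inv(I - delta / 2 * A) @ (I + delta / 2 * A)
        B_bar = inv(I - delta / 2 * A) @ (delta * B)
    elif method == "zoh":       # exact solution, needs a matrix exponential (cubic)
        A_bar = expm(delta * A)
        B_bar = inv(A) @ (expm(delta * A) - I) @ B
    return A_bar, B_bar

A = np.array([[0.0, 1.0], [-1.0, -0.5]])   # a stable toy state matrix
B = np.array([[0.0], [1.0]])
for method in ["euler", "bilinear", "zoh"]:
    A_bar, _ = discretize(A, B, delta=0.1, method=method)
    print(method, np.round(A_bar, 4))      # all three are close for small delta
```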
SSM as a Recurrent Model
So… A hidden state $h_t$ which depends on the previous hidden state $h_{t-1}$ and the input $x_t$… This sounds familiar… It is actually just a linear RNN!
Caution 2
For simplicity, from now onwards, I’ll ignore matrix $D$ and the term $D\,x_t$. It can simply be seen as a skip connection between input and output (not very interesting stuff).
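As a minimal sketch (made-up shapes, random parameters), the discrete SSM really is just this loop, with no non-linearity anywhere:

```python
import numpy as np

N, T = 4, 10                                   # hidden size and sequence length (made up)
A_bar = 0.1 * np.random.randn(N, N)            # kept small so the recurrence stays stable
B_bar, C = np.random.randn(N, 1), np.random.randn(1, N)
x = np.random.randn(T)                         # 1-D input sequence

h, y = np.zeros((N, 1)), np.zeros(T)
for t in range(T):                             # a plain *linear* RNN step
    h = A_bar @ h + B_bar * x[t]
    y[t] = (C @ h).item()
```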
SSM as a Convolution
Here, the key difference wrt other RNN architectures is that all dependencies are linear13. This allows us to unravel the multi-step iteration into a one-step convolution14, which yields a massive computation speedup. In general, we obtain this closed form:

$$y_k = \sum_{j=0}^{k} C\,\bar{A}^{\,j}\,\bar{B}\,x_{k-j}$$
13 I can’t stress enough how important this is.
14 If the whole input sequence is available (it is usually available at train time but not at inference time)
Note 3: Unravel derivation
If we take $h_{-1} = 0$, we have:

$$h_0 = \bar{B}x_0, \qquad h_1 = \bar{A}\bar{B}x_0 + \bar{B}x_1, \qquad h_2 = \bar{A}^2\bar{B}x_0 + \bar{A}\bar{B}x_1 + \bar{B}x_2, \;\dots$$

Then, since $y_k = C\,h_k$, the output can be expressed as:

$$y_0 = C\bar{B}x_0, \qquad y_1 = C\bar{A}\bar{B}x_0 + C\bar{B}x_1, \qquad y_2 = C\bar{A}^2\bar{B}x_0 + C\bar{A}\bar{B}x_1 + C\bar{B}x_2, \;\dots$$

From here it is easy to generalize to the closed form presented above.
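A quick numerical sanity check of this unravelling (my own toy script with random matrices): the recurrent loop and the closed-form sum give exactly the same outputs.

```python
import numpy as np

np.random.seed(0)
N, T = 4, 16
A_bar = 0.9 * np.eye(N) + 0.05 * np.random.randn(N, N)   # roughly contractive
B_bar, C = np.random.randn(N, 1), np.random.randn(1, N)
x = np.random.randn(T)

# Recurrent view
h, y_rnn = np.zeros((N, 1)), np.zeros(T)
for t in range(T):
    h = A_bar @ h + B_bar * x[t]
    y_rnn[t] = (C @ h).item()

# Convolutional view: K[j] = C A_bar^j B_bar
K = np.array([(C @ np.linalg.matrix_power(A_bar, j) @ B_bar).item() for j in range(T)])
y_cnn = np.array([sum(K[j] * x[t - j] for j in range(t + 1)) for t in range(T)])

print(np.allclose(y_rnn, y_cnn))   # True: both views give the same output
```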
Caution 4: On the stability of this type of thing
Pro tip: Whenever you see something like $\bar{A}^{\,k}$ where $k$ can be big, your alarms should go like this: 🚨
If your matrix $\bar{A}$ is expansive, i.e.:

$$\rho(\bar{A}) > 1 \quad \text{(spectral radius greater than 1)}$$

or, equivalently, any:

$$|\lambda_i(\bar{A})| > 1$$

you had better forget about it: $\bar{A}^{\,k}h$ will explode and everything will diverge. On the contrary, if your matrix is very contractive, $\bar{A}^{\,k}h$ will collapse. 15
In our case, we need a contractive matrix: when computing the last elements of the output $y$, it will have collapsed the contribution of the first inputs. However, intuitively, most things seen very long ago should have little influence on the current output anyway (at least for now 😉).
15 This is the same problem we have with vanishing and exploding gradients in very deep (or recurrent) neural networks. For DNNs, residual skip connections saved the day by allowing gradients to flow through the network without being affected by expansive or contractive matrices multiple times.
We can extract these coefficients into what is called the SSM kernel:

$$\bar{K} = \left(C\bar{B},\; C\bar{A}\bar{B},\; C\bar{A}^{2}\bar{B},\; \dots,\; C\bar{A}^{L-1}\bar{B}\right)$$

So, essentially, we can express the mapping between $x$ and $y$ as this convolution 16:

$$y = \bar{K} * x$$

16 Note that this is a giant filter (as long as the sequence). And its components $C\bar{A}^{j}\bar{B}$ are matrices of shape output_dim × input_dim.
The kernel containing the $C\bar{A}^{j}\bar{B}$ matrices slides towards the right.
Note 5: Rambling: On the implementation of long convolutions
The previous visualization of a 1-D convolution is for interpretation purposes only. In practice, such a long convolution is computed with the Fast Fourier Transform, relying on the convolution theorem: a circular convolution in the time domain is an element-wise multiplication in the frequency domain.

Notice that to utilize this theorem for non-circular convolutions (as in our case), we need to pad the input sequences with zeros, and then un-pad the output sequence.
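Here is a minimal sketch of that FFT trick (scalar kernel, a toy exponentially-decaying filter I made up): zero-pad, multiply in the frequency domain, un-pad.

```python
import numpy as np

def long_conv_fft(K, x):
    """Causal convolution y[t] = sum_j K[j] * x[t-j] via the FFT, in O(L log L)."""
    L = len(x)
    n = 2 * L                                  # zero-pad to avoid circular wrap-around
    y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(x, n), n)
    return y[:L]                               # un-pad: keep only the valid part

x = np.random.randn(1024)
K = 0.9 ** np.arange(1024)                     # toy exponentially-decaying kernel
y_fft = long_conv_fft(K, x)
y_naive = np.convolve(K, x)[:1024]             # O(L^2) reference
print(np.allclose(y_fft, y_naive))             # True
```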
So… RNN or CNN?
Alright, so we saw that we can use these SSMs both as a RNN or as a CNN. But how is that useful?
The let’s-not-think-too-much answer is quite simple; use the:
CNN version when you have the whole input sequence available: you usually do have it at train time. This means that both the forward and backward passes are highly-parallelizable one-step operations, in contrast to training RNNs, which is known to be very inefficient17.
RNN version when the whole input sequence is NOT available. Mainly when you are using your model in an auto-regressive way18 or it is ingesting a stream of data19. You usually use this version at inference time. At each time-step you just get $h_t$ and $y_t$ from $h_{t-1}$ and $x_t$ by applying the recursion expressed in Equation 3 (a couple of matrix multiplications).
17 Mainly because you cannot parallelize the processing of a sequence: you need to iterate one-element-at-a-time, compute the output, store the gradients, and then do the backpropagation through time
18 E.g: performing next-token prediction.
19 E.g: real-time audio transcription.
We want the CNN version to train, and the RNN version to infer.
Important 6: COMPLEXITY ANALYSIS
Notice that these are the same characteristics that make the transformer architecture so powerful20. However, standard transformers require:
Quadratic memory at train-time (wrt input length)
Linear memory at inference time, which is prohibitive for long sequences.
What makes SSMs so attractive is that they work with:

Quasi-linear time and linear memory at train time (wrt input length), thanks to the CNN view.

Constant memory per step at inference time, thanks to the RNN view: the fixed-size hidden state is all you need to carry around.
This is all very nice, but (quite unsurprisingly) its vanilla implementation performs very poorly in practice: it struggles to model non-linear dynamics22. In this section we’ll go through the main ideas from some papers which introduce key improvements23: HiPPO, S4, DSS, S6 (Mamba).
22 Fancy way of saying that the change in the output is not proportional to a change in the input
23 They are sorted chronologically (as they build on top of each other).
HiPPO
Ok, let’s forget for a moment about SSMs (trust me on this one, I know what I’m doing). Let’s try to solve the following problem instead: How can I compress everything I’ve seen so far into a fixed representation? 24
24 I find the intersection of data compression and machine learning interesting and I explored it a bit in another post in which I ramble around information theory.
HiPPO25 introduces a general framework for the online compression of continuous signals (and discrete time series) by projecting them onto polynomial bases. This framework can easily be integrated into RNNs (such as SSMs). In addition, it is a generalization of previous models such as LSTMs and GRUs, improving on them and achieving SOTA on permuted MNIST (a popular continual learning benchmark).
For now, all you need to know is that initializing the SSM state matrix $A$ with “the HiPPO matrix” makes the model a more effective learner 26:

$$A_{nk} = \begin{cases} \sqrt{(2n+1)(2k+1)} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$
26 It is observed that simply modifying an SSM from a random matrix to HiPPO improved its performance on the sequential MNIST classification benchmark from 60% to 98%.
In TL;DR terms: this matrix makes the model’s hidden state $h_t$ a compressed representation of the whole seen input $x_{0:t}$. This means that the hidden state contains all the information required to (approximately) reconstruct $x_{0:t}$. Thus, we are effectively compressing the whole input sequence into a single point (vector).
Note 7: What about the not-so-TLDR terms?
Problem
HiPPO formulates the problem of online function approximation as follows. Given a continuous function:

$$f(t) : \mathbb{R}_{\geq 0} \to \mathbb{R}$$

The goal is to find a fixed representation (vector):

$$c(t) \in \mathbb{R}^{N}$$

Such that $c(t)$ captures the history of $f$ up to time $t$ (denoted $f_{\leq t}$).

But how? The idea is to project the seen function $f_{\leq t}$ onto an $N$-dimensional function space spanned by some basis functions, and store the coefficients of the projection in $c(t)$. For this, we need to specify a measure and a basis.
Measure
A measure (or weight / error function) that tells how much we “care” about every time in the past. In the paper, they focus on these two cases:
Translated Legendre Measure (LegT): We only “care” about recent events, so we focus on a sliding window of length $\theta$ looking back from the present (past-wise, if that word exists).
Scaled Legendre Measure (LegS): We “care” equally about everything we’ve seen so far.
The black line is the input function $f(t)$. The red box represents the measure’s weight at one time; the blue box represents the measure’s weight at a later time.
Visualization of the first six Legendre polynomials (they can easily be obtained by construction from $P_0$ and $P_1$ via Bonnet’s recurrence relation).
One can calculate the coefficients of the projection of a function $f$ in terms of Legendre polynomials by solving:

$$c_n = \frac{2n+1}{2} \int_{-1}^{1} f(x)\, P_n(x)\, dx$$

Then, we have that:

$$f(x) \approx \sum_{n=0}^{N-1} c_n\, P_n(x)$$
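As a toy, static (non-online, no special measure) illustration of such a projection, using NumPy’s Legendre utilities on a function I made up:

```python
import numpy as np
from numpy.polynomial import legendre

# Project f(x) = sin(3x) + x**2 on [-1, 1] onto the first N Legendre polynomials
N = 8
x = np.linspace(-1, 1, 2000)
f = np.sin(3 * x) + x ** 2

c = legendre.legfit(x, f, deg=N - 1)     # least-squares fit ≈ projection coefficients
f_hat = legendre.legval(x, c)            # reconstruction from the N coefficients

print(c.shape)                           # (8,): the whole function compressed into 8 numbers
print(np.max(np.abs(f - f_hat)))         # small reconstruction error
```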
However, in this case, we are concerned with the approximation in a dynamic context, where the coefficients change over time (as we “experience” more of $f$). Check the paper for details, but it turns out that there exists a closed-form solution to this approximation problem, given by this linear differential equation:

$$\frac{d}{dt} c(t) = A(t)\, c(t) + B(t)\, f(t)$$
Summary of the HiPPO framework under the LegS measure: the projection coefficients $c(t)$ evolve through time according to a linear dynamical system. The reconstructed function is a linear combination of basis elements weighted by $c(t)$. The colored boxes represent the used measures (weight given through time).
Essentially, the coefficients’ rate of change depends linearly on their current value $c(t)$ and on $f(t)$. $A$ and $B$ depend on the chosen measure, and following these dynamics (aka solving the ODE) one can find the coefficients that optimally approximate $f_{\leq t}$ according to the chosen measure.
For instance, in the measures presented earlier, we obtain the following expressions:
HiPPO LegT Operator:
HiPPO LegS Operator:

$$A_{nk} = \begin{cases} \sqrt{(2n+1)(2k+1)} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} \qquad B_n = \sqrt{2n+1}$$
Notice this is the state matrix we “pick” as a starting point for our model, somehow “helping” it to learn the projection of the seen part of the input onto Legendre polynomials with the scaled measure. 🥳
If curious, this is how the matrix looks for $N = 8$:
Code
```python
# Please don't judge the code, it's all ChatGPT
import numpy as np
import matplotlib.pyplot as plt

def generate_matrix(N):
    A = np.zeros((N, N))  # Initialize an N x N matrix with zeros
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n, k] = n + 1
            else:
                A[n, k] = 0  # Technically not needed, as the matrix is initialized with zeros
    return A

# For example, generate an 8x8 matrix
N = 8
matrix = generate_matrix(N)

plt.figure(figsize=(9, 7))
plt.imshow(matrix, cmap='viridis', aspect='equal')
plt.colorbar()
plt.title("HiPPO matrix")

# Annotate the heatmap with the matrix values
for i in range(N):
    for j in range(N):
        plt.text(j, i, f"{matrix[i, j]:.2f}", ha="center", va="center", color="w")

plt.show()
```
S4
S427 takes a basic SSM model and adds two key contributions:
27 Structured State Space for Sequence Modelling paper (Oct 2021). Note: Structured, because it imposes structure on the $A$ matrix.
It uses the HiPPO matrix as the “state matrix” $A$, which they freeze (i.e. it is not learnable28).
It implements a very efficient way to compute the convolution operation. In summary: they do not materialize the full kernel in memory, but use a generating function to obtain the needed value at each moment.
28 The only learnable params are $B$, $C$, and the step size $\Delta$.
Caution 8: This is SISO!
The SSM layer implementation only considers 1-dimensional input-output sequences (i.e. input_dim = output_dim = 1). This is known as SISO (Single-Input Single-Output). Notice that in this case the kernel $\bar{K}$ is a vector.
To manage multivariate input-output, they stack such layers (one for each dimension 29) and couple them together with nonlinear mixing layers (to break the independence assumption).
In the S530 paper, they generalize the S4 layer to be multivariate: not SISO anymore but MIMO (Multi-Input Multi-Output).
30 Simplified Structured State Space for Sequence Modelling (March 2023)
29 Resulting in stacked SSMs with different parameters.
To make it apparent that dimensions of input and output pairs are independent, I represent them as “batched”. In this example: input_dim = output_dim = 3, and hidden_dim = 4. Thus, we get 3 kernels of dimension 1xL, which we apply in parallel.
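A shape-level sketch of what the caption describes (random 1×L kernels just to show the independent per-channel convolutions; a real implementation would build each kernel from its own SSM parameters):

```python
import numpy as np

L, D = 32, 3                          # sequence length and number of channels (as in the figure)
x = np.random.randn(L, D)             # multivariate input, treated as D independent 1-D signals

# One independent SISO kernel (1 x L) per channel
K = np.random.randn(D, L)             # in practice each row would be (C_d A_d^j B_d)_j
y = np.stack(
    [np.convolve(K[d], x[:, d])[:L] for d in range(D)],  # D convolutions, trivially parallel
    axis=-1,
)
print(y.shape)                        # (L, D); a nonlinear mixing layer would follow
```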
Note 9: Intuition behind the efficient kernel computation
Remember that 1-dimensional SSMs are defined by the matrices $(A, B, C)$, where $A$ is $N \times N$, $B$ is $N \times 1$, and $C$ is $1 \times N$.
The following is an embarrassing over-simplification31 of the logic they follow to optimize the computation of the kernel:
They notice there exists a special case of SSM matrix structure for which computing a truncated generating function of the kernel becomes very fast 32: in particular, if the SSM has a Diagonal Plus Low-Rank (DPLR) decomposition in complex space. This means that the state matrix needs to have the following structure:

$$A = \Lambda - P\,Q^{*}$$

where $\Lambda$ is diagonal and $P, Q$ are low-rank33 matrices.
The HiPPO matrix does not have this form directly, but it is Normal Plus Low-Rank (NPLR); and since normal matrices can be diagonalized by a unitary matrix (which does not change the SSM’s input-output behaviour), NPLR matrices are thus equivalent to DPLR matrices from the perspective of SSM models.
This means we can do the decomposition in complex space and get the generating function of the kernel, so we don’t need to fully materialize it in memory 🥳
33 Actually in this case the rank is 1 because they are vectors.
32 You don’t actually pre-compute the kernel, but a generating function making use of the z-transform
31 I will not open this melon 🍉 because it is very math heavy, but you can find an explanation here.
Note 10: S4 experiments
Standard MNIST classification (treating each image as a sequence of 784 pixels): As reported, an S4 model of and -layers gets to around 99% accuracy right away.
MNIST next-pixel prediction. Like next-token prediction in LMs but with pixel values: given an initial string (prompt), the model should complete the sequence. A model with and -layers ( million params) achieved SOTA according to PapersWithCode. Here is an example of a prediction (sample) and the ground truth (the model was fed the white part, the first 300 pixels, and inferred the rest):
QuickDraw next-pixel prediction. Same as before, but with images representing more things than just digits. A model with and -layers generated “relatively” coherent completions (the prompt was the first 500 pixels).
Spoken digits classification. Same as MNIST classification, but the input is an audio wave of people reading digits from 0 to 9. The model achieved 97% accuracy from the raw soundwave.
Spoken digits generation. The model gets prompted with the beginning of the audio wave and has to continue it. The cherry-picked examples sound very good (even the tone of voice is matched: check some examples here). The dataset has around 3k examples of around 6400 steps, at an 8kHz sampling rate, discretized into 256 classes with µ-law encoding.
Apparently it also performs well at movie-clip classification: see the ViS4mer paper. [PERSONAL NOTE]: I like the idea of combining the power of a transformer for fixed-size data (a frame) with the power of S4 for analyzing time series (a sequence of frames).
ViS4mer architecture
DSS
DSS34 shows that one can match S4’s level of performance by using a diagonal state matrix (i.e. dropping the low-rank correction). This massively simplifies the computation and makes the implementation much more straightforward.
35 This makes sense if you have read the explanation of S4, otherwise don’t worry too much about it.
Again, the model is very sensitive to parameter initialization. They initialize the diagonal with the eigenvalues of the normal part of the NPLR form of the HiPPO matrix (selecting the ones that ensure a stable system).35
DSS convolution. Same as S4, but using diagonal matrices.
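To see why a diagonal state matrix simplifies things so much, here is a toy sketch of mine (real-valued diagonal for simplicity; DSS actually works in complex space): with $\bar{A} = \mathrm{diag}(\lambda)$, the kernel becomes a single Vandermonde-style matrix product instead of repeated matrix powers.

```python
import numpy as np

N, L = 4, 16
lam = np.random.uniform(0.5, 0.95, N)           # diagonal of A_bar (kept contractive)
B_bar, C = np.random.randn(N), np.random.randn(N)

# Kernel: K[j] = sum_n C[n] * lam[n]**j * B_bar[n]  ->  one Vandermonde matmul
V = lam[None, :] ** np.arange(L)[:, None]        # V[j, n] = lam[n]**j, shape (L, N)
K_fast = V @ (C * B_bar)                         # shape (L,)

# Same kernel with explicit matrix powers (what you'd need for a dense A_bar)
A_bar = np.diag(lam)
K_slow = np.array([C @ np.linalg.matrix_power(A_bar, j) @ B_bar for j in range(L)])
print(np.allclose(K_fast, K_slow))               # True
```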
S6 (Mamba)
Enough chit-chat! Let’s address the 🐘 in the room: our model dynamics36 are constant. Look again at Equation 1 and Equation 3: $A$, $B$ and $\Delta$ (and thus $\bar{A}$, $\bar{B}$) remain fixed through the sequence! This property is known as Linear Time Invariance (LTI), and it is highly limiting in terms of modelling power. However, if we break LTI (by making these matrices variable), we lose the possibility to process the input time series in one step (we lose the CNN aspect of the SSM, as the “kernel” is not fixed any more 📉).
36 How things change through time.
37 Mamba: Linear-Time Sequence Modeling with Selective State Spaces paper (Jan 2024)
Mamba37 does exactly that, adding two contributions: selectivity (making the matrices depend on the input $x_t$) and, as a consequence, the “scan operation” (a convolution substitute).
Combined together, they become the selective scan algorithm38.
38 S6 because this is S4 + Selective Scan
Note 11: Why the Mamba name?
Make of this whatever you want:
SSSSSS like the snakes
Selectivity
Mamba breaks the SSM’s LTI property by allowing its parameters 39 to depend on the current input. We now have:

$$h_t = \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t\, h_t$$
39 They keep $A$ fixed (not input-dependent); they use the same one as in DSS.
Where the matrices 40 of the discretized version are computed the same way as explained before. But now, we have that:

$$B_t = \mathrm{Linear}_N(x_t), \qquad C_t = \mathrm{Linear}_N(x_t), \qquad \Delta_t = \mathrm{softplus}\big(\text{parameter} + \mathrm{Linear}_1(x_t)\big)$$
In the paper, they show the differences between S4 and S6 like so:
They refer to this property as “selectivity”: because now the model can be more selective about what gets included in the state. Some inputs might be more relevant than others for the given task41; by making $B$ and $C$ dependent on $x_t$, we can more easily discard non-relevant stuff.42 This way we ensure that not all inputs are treated equally: VIP ones have a bigger influence.
41 For instance, in audio transcription it might not be very important to remember that 5 minutes ago there was a long pause. (I talk a lot about transcription because it’s currently paying my bills).
42 Having $\Delta$ dependent on $x_t$ has a similar effect: it is like choosing the step size based on how “interesting” an observation is.
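Shape-wise, the selective parameters can be sketched as simple per-time-step projections of the input. This mirrors the parameterization described above, but treat it as an illustrative sketch (made-up dimensions, random weights), not the actual implementation:

```python
import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

L, D, N = 16, 8, 4                         # seq length, input dim, state dim (made up)
x = np.random.randn(L, D)

# Input-dependent parameters: one (B_t, C_t, delta_t) per time step
W_B, W_C = np.random.randn(D, N), np.random.randn(D, N)
W_delta, b_delta = np.random.randn(D, 1), np.zeros(1)

B_t = x @ W_B                              # (L, N): what to write into the state
C_t = x @ W_C                              # (L, N): what to read out of the state
delta_t = softplus(x @ W_delta + b_delta)  # (L, 1): positive, input-dependent step size

print(B_t.shape, C_t.shape, delta_t.shape)
```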
Note 12: Why can’t we do a convolution operation? (Visualized)
Very simple: the convolution operation is designed for a convolution kernel which is always the same. Here, the coefficients of the kernel change at each step depending on the input!
For instance, when processing $x_t$ we have that the first element of the kernel is $C_t\bar{B}_t$:

Whereas on the next convolution step, we now have that the first element is $C_{t+1}\bar{B}_{t+1}$ 😠
In conclusion, we cannot use the convolution operation, so we need to go back to the recurrent version of the model.
The scan operation
Alright, let’s forget about the CNN then. Let’s take a step back and look into the RNN. A naive way of implementing the computation (when the complete sequence is provided)43 would be to do something like this:
43 I’ll focus on the computation of h[t] since getting y[t] from h[t] is a simple matrix multiplication.
```python
T = len(x)
h[0] = B_bar[0] * x[0]  # This is a vector
for t in range(1, T):
    h[t] = A_bar[t] * h[t-1] + B_bar[t] * x[t]
```
An (overcomplicated)44 representation of this computation would be this:
44 Please bear with me, it’ll make understanding the scan operation much simpler.
First we compute B_bar[t] * x[t]. This is a tensor multiplication without previous time-step dependencies, which can easily be parallelized. Then, for the iterative part, at every step t we take the output of the previous step h[t-1], multiply it by A_bar[t] and add B_bar[t] * x[t], which gives us h[t]. Notice that we need 8 steps to compute the output of a sequence of length 8.
This is a very efficient way to compute it in terms of the number of operations. However, remember that the hardware (GPUs, TPUs) we use to train these types of models is optimized for high parallelization. And, ultimately, what we care about is the time to complete the computation. Considering this, is there a way to compute it faster? Yes, there is! Taking inspiration from the parallel scan operation45.
45 This thing has more names than a Spanish telenovela character. It is referred to as: cumulative sum, or prefix sum, or inclusive scan, or just scan…
46 It is not the most work-efficient one, so it is not what Albert and Tri do, but I think it is useful to get the main idea behind this.
47 Notice that now h is a matrix of shape (log_2(T) + 1) x T. I add a new dimension to h to store the intermediate steps. In the end, we will use the last row.
There exist different implementations of this operation with varying properties; check Wikipedia. Here I’m going to focus on the simplest one to grasp, just to get the gist of it46. The idea is the following47:
```python
import numpy as np

T = len(x)                         # assume T is a power of two
num_levels = int(np.log2(T))

# h[i, t] holds the recurrence result over a window of length 2**i ending at t
# (scalar entries for simplicity; with vector states each entry would itself be a vector)
h = np.zeros((num_levels + 1, T))
h[0, :] = B_bar * x                # broadcasted for all t (no time dependencies)

for i in range(num_levels):
    offset = 2 ** i
    for t in range(T):             # no cross-t dependencies: on a GPU this loop runs IN PARALLEL
        if t < offset:
            h[i + 1, t] = h[i, t]
        else:
            # accumulated product A_bar[t] * A_bar[t-1] * ... * A_bar[t - offset + 1]
            A_acc = np.prod(A_bar[t - offset + 1 : t + 1])
            h[i + 1, t] = h[i, t] + A_acc * h[i, t - offset]

# the last row now contains h[t] for every t
```
Using the same diagram as before, we can now represent the computations as follows48:
48 Not so stupid after all to represent it the way I did, is it?
The general idea is to compute the sequence in parts and iteratively combine them thanks to the associative property. Thus, we now only need log_2(T) steps to compute the output! 🥳 (if your machine has at least T processors).
Note 13: If things still don’t click…
It might be useful to see the “path” the computation follows to get a particular output. I acknowledge there is an interesting mess of lines:
As you can see, it depends on all previous inputs and the powers of A are correctly assigned. Notice I simplified the problem by assuming a constant $\bar{A}$ for a clearer explanation. The idea is the same with an input-dependent $\bar{A}_t$.
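The reason any parallel-scan scheme applies here is that one recurrence step $h \mapsto \bar{A}_t h + \bar{B}_t x_t$ can be represented as the pair $(\bar{A}_t, \bar{B}_t x_t)$, and composing two such steps is associative. A scalar toy sketch of that combine operator (my own illustration, not Mamba’s CUDA kernel):

```python
from functools import reduce

def combine(step1, step2):
    """Compose h -> a1*h + b1 followed by h -> a2*h + b2 (an associative operation)."""
    a1, b1 = step1
    a2, b2 = step2
    return (a2 * a1, a2 * b1 + b2)

steps = [(0.9, 1.0), (0.8, 2.0), (0.7, 3.0), (0.6, 4.0)]   # (A_bar[t], B_bar[t] * x[t]) pairs

left_to_right = reduce(combine, steps)                      # sequential evaluation
halves = combine(reduce(combine, steps[:2]), reduce(combine, steps[2:]))  # divide & conquer
print(left_to_right, halves)                                # identical: grouping doesn't matter
```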
On top of that, they are very conscious of how to efficiently perform these operations on a GPU49.
Hardware-aware implementation of the scan operation. In this case, we have an input_dim = output_dim = 5, and a hidden_dim = 4. The idea is that they only materialize the hidden states in the most efficient level of memory (GPU SRAM). This avoids the slow operation of copying these values from SRAM into the HBM50 (this slow IO transfer speed is one of the main bottlenecks of current GPUs).
Last remark: When training, they do not store the intermediate states in the forward pass. Instead, they recompute them again in the backward pass, which ends up being more efficient. This idea is also a common trick to reduce the memory footprint at train time51.
Note 14: Rambling: SRAM, HBM and custom CUDA kernels
The optimization in Mamba is somewhat in line with FlashAttention (May 2022), proposed by Tri Dao et al. The idea is to substitute the PyTorch operations (or those of whichever deep learning framework is being used) with a custom CUDA kernel52 (aka fused kernel) which combines them and performs them more efficiently:
Common flow of operations done in PyTorch. Image from here.
Flow of operations in a memory-aware fused CUDA kernel. Image from here.
For instance, in the case of FlashAttention, they substitute the omnipresent dot-product attention operations of the transformer with a single fused operation.
This (obviously) results in a faster and more memory-efficient attention method. The approach not only combines the operations but also better utilizes the GPU memory hierarchy. For instance, an A100 GPU has:
40GB (or 80 GB) of HBM (High-Bandwidth-Memory): Large but slow: This is implemented by stacking multiple DRAM (Dynamic Random Access Memory) dies allowing high parallel data transfers.
~0.2MB x 108 processors of SRAM (Static Random Access Memory): Small but fast. The speed advantage of SRAM over DRAM comes from SRAM’s ability to hold a given data bit in a static state (on or off) as long as power is supplied. Moreover, it is more reliable than DRAM, which must refresh its stored data bits many times per second to maintain the integrity of the stored data (making operations slower). However, SRAM has a much higher cost-per-bit and requires more physical space on the chip. Thus, it is usually reserved for operations where speed and reliability are critical (such as caches in CPUs or GPUs).
52 Fancy way of saying: function that runs on GPU written in CUDA
The mamba block
Similar to a transformer block, they provide a “Mamba block”, which can be stacked to build larger models. It basically adds feature expansion, a CNN and a skip connection around the SSM layer like so:
If you check the official implementation, specifically the Mamba block, you’ll see these important parameters: d_model (input and output dimensions), d_state (SSM hidden state dim), d_conv (conv-1d kernel size), expand (initial projection expansion factor). I referenced them in the diagram.
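For intuition, here is a shape-only sketch of the block using those parameters. It is a rough reconstruction from the diagram and description above (random weights, an averaging filter standing in for the learned depthwise conv, and a placeholder where the selective SSM would go), not the official implementation:

```python
import numpy as np

def mamba_block_shapes(x, d_model=64, d_state=16, d_conv=4, expand=2):
    """Follow the tensor shapes through a (very) simplified Mamba block. x: (L, d_model)."""
    L, d_inner = x.shape[0], expand * d_model
    u = x @ np.random.randn(d_model, 2 * d_inner)           # in-projection: feature expansion
    u, gate = u[:, :d_inner], u[:, d_inner:]                 # SSM branch and gating branch
    u = np.stack([np.convolve(u[:, c], np.ones(d_conv) / d_conv)[:L]
                  for c in range(d_inner)], axis=-1)         # short depthwise 1-D conv (size d_conv)
    y = u                                                    # <- the selective SSM (state dim d_state)
                                                             #    would transform u here
    y = y * gate                                             # multiplicative gate / skip-style path
    return y @ np.random.randn(d_inner, d_model)             # out-projection back to d_model

print(mamba_block_shapes(np.random.randn(32, 64)).shape)     # (32, 64): same shape in, same shape out
```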
And that’s pretty much it 🙃
Epilogue
I’d just like to add that I found the line of research by Albert and Tri extremely inspiring. It’s super cool to see such talented people thinking of new approaches to solve temporal-series problems (instead of micro-optimizing transformers). I like seeing the results of combining ideas from so many fields: control theory from SSMs, function projections from 🦛, and parallel computing and hardware optimization from 🐍
In a future post I’ll explore the relationship between the original transformer, the linear transformer, Mamba, and maybe some other RNN models, paying special attention to similarities in architecture and time & memory complexities at train/eval time53.
53 As I was writing this (May 2024), Mamba2 came out and they do a lot of this work haha, so I guess I’ll be summarizing that :)
References
To create this blog I “consumed” and summarized these amazing resources: