HiPPOs 🦛, Mambas 🐍, and other creatures

I go through the research journey that led to Mamba. I first review SSMs, then explore the models: HiPPO, S4, DSS, and finally Mamba.
Author

Oleguer Canal

Published

May 18, 2024

Have you ever wondered: Wait, what if someone explained Mamba worse? Well, buckle up 💺, because today we’re doing that! First we are going to review SSMs. Then, we’ll check out recent deep learning contributions to them. And with this, we’ll finally be ready to understand Mamba: the strongest contender to the Transformer’s dominance 😱 Mamba surpasses the performance of transformers twice its size while maintaining sub-quadratic memory requirements!

1 Short for State Space Models

2 I will mainly focus on the Albert Gu, Tri Dao, et al. line of research.

3 This post assumes you are familiar with how basic RNN, CNN, and Transformer models work.

My thoughts when I first learned about SSMs.

State Space Models (SSMs)

SSMs are traditionally used in control theory to model a dynamic system via state variables . Given:

4 Fancy way of saying that functions are dependent on time (or that they are sequential)

5 SSMs’ core ideas are so omnipresent that they have been re-discovered multiple times, resulting in a terminology mess… Control nerds say state variables for what deep learning geeks call hidden states, which is the same as latent variables in the probabilistic world. Potatos, potatoes, potatoisses 🥔

  • An N-dim input time series: $x(t) \in \mathbb{R}^N$

  • A P-dim output time series: $y(t) \in \mathbb{R}^P$

6 Think of a trajectory (aka line) in an N-dimensional space.

7 Think of another trajectory but now in a P-dimensional space 🤪

SSMs model the mapping between both time-series with this linear differential equation:

$$h'(t) = A\,h(t) + B\,x(t) \qquad y(t) = C\,h(t) + D\,x(t) \tag{1}$$

Where:

  • $h(t) \in \mathbb{R}^H$ is the hidden state

  • $A \in \mathbb{R}^{H \times H}$, $B \in \mathbb{R}^{H \times N}$, $C \in \mathbb{R}^{P \times H}$, $D \in \mathbb{R}^{P \times N}$ are learnable params.

8 Aka “state matrix”.

Discretization of SSMs

In practice, we’ll always be using discrete data. Thus, we need a way to discretize Equation 1. It turns out the discrete version takes the following form:

9 In the digital era we live in, truly continuous information processing is increasingly rare 🧑‍🦳. Nowadays, commonly worked with sequences include: series of text tokens, slices of images, or samples of audio waveforms.

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t \qquad y_t = \bar{C}\,h_t + \bar{D}\,x_t \tag{2}$$

Where $\bar{A}, \bar{B}, \bar{C}, \bar{D}$ have closed formulas in terms of $A, B, C, D$ and a new (learnable) step-size parameter $\Delta \in \mathbb{R}_{>0}$. These closed formulas depend on the discretization method used, which is one of the key differentiators between distinct SSM architectures.

10 Defines the granularity of the approximation.

The S4 blog presents many great discretization interpretations. The one that resonated with me the most is to numerically approximate the derivative term in Equation 1. Depending on the approximation method used, we get different formulas for $\bar{A}, \bar{B}, \bar{C}, \bar{D}$. Let’s see some examples, shall we? Yes, we shall:

Euler’s method

For instance, we can approximate the ODE numerically using Euler’s method:

$$f'(x) \approx \frac{f(x+\Delta) - f(x)}{\Delta}$$

Using our notation:

$$h'_{k-1} \approx \frac{h_k - h_{k-1}}{\Delta}$$

Then:

$$h_k = h_{k-1} + \Delta\,h'_{k-1} = h_{k-1} + \Delta\,(A h_{k-1} + B x_k) = (I + \Delta A)\,h_{k-1} + (\Delta B)\,x_k =: \bar{A}\,h_{k-1} + \bar{B}\,x_k$$

Thus, if using this approximation:

$$\bar{A} := I + \Delta A \qquad \bar{B} := \Delta B \qquad \bar{C} := C \qquad \bar{D} := D$$

Nice! Now, given an $\text{SSM} = \{A, B, C, D\}$ we can find its discretized form $\text{SSM}_\Delta = \{\bar{A}, \bar{B}, \bar{C}, \bar{D}\} = \{I + \Delta A,\ \Delta B,\ C,\ D\}$.
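If it helps, here is a minimal NumPy sketch of this Euler discretization (function and variable names are my own):

import numpy as np

def discretize_euler(A, B, delta):
    """Euler discretization: A_bar = I + delta * A, B_bar = delta * B (C and D stay unchanged)."""
    A_bar = np.eye(A.shape[0]) + delta * A
    B_bar = delta * B
    return A_bar, B_bar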

Trapezoid rule

Similarly (but more accurately), we can approximate the ODE using the trapezoid rule. This is the approximation method used in S4 (the paper preceding Mamba).

In this case, one gets that:

$$\bar{A} := \left(I - \tfrac{\Delta}{2}A\right)^{-1}\left(I + \tfrac{\Delta}{2}A\right) \qquad \bar{B} := \left(I - \tfrac{\Delta}{2}A\right)^{-1}\Delta B \qquad \bar{C} := C \qquad \bar{D} := D$$

Interpolation

Interestingly, these numerical approximation results are in fact function approximations of these (exact) results:

$$\bar{A} := e^{\Delta A} \qquad \bar{B} := A^{-1}\left(e^{\Delta A} - I\right)B \qquad \bar{C} := C \qquad \bar{D} := D$$

  • Euler’s method yields an approximation of $e^x \approx 1 + x$.
  • The trapezoid rule ends up with a first-order Padé approximation: $e^x \approx \frac{1 + x/2}{1 - x/2}$.

Why not use this formulation then? Notice that one needs to balance computational cost and numeric accuracy. For instance, ZOH (Zero-Order Hold) uses $\bar{A} := e^{\Delta A}$, which takes cubic time, whereas with the bilinear transformation $\bar{A}$ can be computed in linear time.

12 This leads into a bilinear transformation often used in control theory to convert from continuous-time to discrete-time and vice-versa.

11 Again, the different formulations of $\bar{A}, \bar{B}, \bar{C}, \bar{D}$ in terms of $A, B, C, D$ are the key spice 🌶️ in your discrete SSM implementation.
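To make the trade-off a bit more concrete, here are NumPy sketches of the bilinear (trapezoid) and ZOH discretizations, mirroring the formulas above. Names are my own, and I assume SciPy is available for the matrix exponential:

import numpy as np
from scipy.linalg import expm

def discretize_bilinear(A, B, delta):
    """Trapezoid rule / bilinear transform (the discretization used in S4)."""
    I = np.eye(A.shape[0])
    left = np.linalg.inv(I - delta / 2 * A)
    A_bar = left @ (I + delta / 2 * A)
    B_bar = left @ (delta * B)
    return A_bar, B_bar

def discretize_zoh(A, B, delta):
    """Zero-Order Hold: the exact exponential formulas, at the price of a matrix exponential."""
    A_bar = expm(delta * A)
    B_bar = np.linalg.solve(A, (A_bar - np.eye(A.shape[0])) @ B)  # A^{-1} (e^{dA} - I) B
    return A_bar, B_bar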

SSM as a Recurrent Model

So… A hidden state $h_t$ which depends on the previous hidden state $h_{t-1}$ and the input $x_t$… This sounds familiar… It is actually just a linear RNN!

Caution 2

For simplicity, from now onwards, I’ll ignore the matrices $D$ and $\bar{D}$. They can simply be seen as a skip connection between input and output (not very interesting stuff).
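In code, the recurrent view is just a loop over time steps. A minimal sketch (dropping $D$ as stated above; names are mine):

import numpy as np

def ssm_recurrent(A_bar, B_bar, C, x):
    """Run the discrete SSM (Equation 2, without D) as a linear RNN. x has shape (L, N)."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in x:
        h = A_bar @ h + B_bar @ x_t  # h_t = A_bar h_{t-1} + B_bar x_t
        ys.append(C @ h)             # y_t = C h_t
    return np.stack(ys)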

SSM as a Convolution

Here, the key difference wrt other RNN architectures is that all dependencies are linear. This allows us to unroll the multi-step iteration into a one-step convolution, which yields a massive computation speedup. In general, we obtain this closed form:

13 I can’t stress enough how important this is.

14 If the whole input sequence x is available (it is usually available at train time but not at inference time)

$$y_k = \bar{C}\bar{A}^k\bar{B}\,x_0 + \bar{C}\bar{A}^{k-1}\bar{B}\,x_1 + \dots + \bar{C}\bar{A}\bar{B}\,x_{k-1} + \bar{C}\bar{B}\,x_k + \bar{D}\,x_k$$

If we take $h_{-1} = 0$, we have:

$$\begin{aligned} h_0 &= \bar{B}x_0 \\ h_1 &= \bar{A}h_0 + \bar{B}x_1 = \bar{A}\bar{B}x_0 + \bar{B}x_1 \\ h_2 &= \bar{A}h_1 + \bar{B}x_2 = \bar{A}^2\bar{B}x_0 + \bar{A}\bar{B}x_1 + \bar{B}x_2 \end{aligned}$$

Then, since $y_t = \bar{C}h_t + \bar{D}x_t$, the output can be expressed as:

$$\begin{aligned} y_0 &= \bar{C}h_0 + \bar{D}x_0 = \bar{C}\bar{B}x_0 + \bar{D}x_0 \\ y_1 &= \bar{C}h_1 + \bar{D}x_1 = \bar{C}\bar{A}\bar{B}x_0 + \bar{C}\bar{B}x_1 + \bar{D}x_1 \\ y_2 &= \bar{C}h_2 + \bar{D}x_2 = \bar{C}\bar{A}^2\bar{B}x_0 + \bar{C}\bar{A}\bar{B}x_1 + \bar{C}\bar{B}x_2 + \bar{D}x_2 \end{aligned}$$

From here it is easy to generalize to the expression for $y_k$ presented above.

Pro life tip: Whenever you see something like $A^k$ where $k$ can be big, your alarms should go like this: 🚨

If your matrix $A$ is expansive, i.e.:

$$\exists\, \alpha > 1 \;\text{ s.t. }\; \|Ax\| > \alpha \|x\| \quad \forall x \in \mathbb{R}^N_{\neq 0}$$

or, equivalently, any:

$$|\lambda_i| > 1 \qquad \big(\text{eigen}(A) = \{\lambda_i\}_i\big)$$

you better forget about it: $A^k$ will explode and everything will diverge. On the contrary, if your matrix is very contractive, $A^k$ will collapse.

In our case, we need a contractive matrix. When computing the last elements of the output $y_k$, the contributions of the first inputs will have collapsed. However, intuitively, most things seen very long ago should have little influence on the current output (at least for now 😉).

15 This is the same problem we have with vanishing and exploding gradients in very deep (or recurrent) neural networks. For DNNs, residual skip connections saved the day by allowing gradients to flow through the network without being affected by expansive or contractive matrices multiple times.
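A tiny toy example (my own) makes the alarm concrete:

import numpy as np

expansive = np.diag([1.05, 0.9])     # one eigenvalue > 1
contractive = np.diag([0.95, 0.5])   # all eigenvalues < 1

print(np.linalg.norm(np.linalg.matrix_power(expansive, 500)))    # ~4e10: early inputs blow everything up
print(np.linalg.norm(np.linalg.matrix_power(contractive, 500)))  # ~7e-12: early inputs vanish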

We can extract these coefficients into what is called the SSM kernel $\bar{K}$:

$$\bar{K} := \left(\bar{C}\bar{A}^k\bar{B}\right)_{k \in [L]} = \left(\bar{C}\bar{A}^{L-1}\bar{B}, \dots, \bar{C}\bar{A}\bar{B}, \bar{C}\bar{B}\right)$$

So, essentially, we can express the mapping between $y$ and $x$ as this convolution:

16 Note that this is a giant filter (as long as the sequence). And its components are matrices of shape $P \times N$.

$$y = \bar{K} * x$$

The kernel containing the matrices moves towards the right.

The previous visualization of a 1-D convolution is for interpretation purposes.

While it works, a more efficient implementation (for long sequences) is to use the convolution theorem with Fast Fourier Transform (FFT). I also recommend checking my notes on convolution vs cross-correlation to clarify some common misconceptions.

The idea is to first multiply the FFTs of the input sequences and then apply an inverse FFT, more or less like so:

import numpy as np

def conv_fft(x, K):
    xd = np.fft.rfft(np.pad(x, (0, K.shape[0])))  # n-point discrete Fourier transform of the padded input
    Kd = np.fft.rfft(np.pad(K, (0, x.shape[0])))  # n-point discrete Fourier transform of the padded kernel
    out = xd * Kd                                 # point-wise product <=> convolution in the time domain
    return np.fft.irfft(out)[: x.shape[0]]        # inverse Fourier transform, then un-pad the output

Notice that to utilize this theorem for non-circular convolutions (as in our case), we need to pad the input sequences with zeros, and then un-pad the output sequence.
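To convince yourself that the recurrent and convolutional views (and the FFT trick) really agree, here is a small self-check reusing the sketches from above (discretize_bilinear, ssm_recurrent and conv_fft), with 1-dimensional input/output so the kernel is a plain vector:

import numpy as np

rng = np.random.default_rng(0)
L, H = 32, 4
A = -np.diag(rng.uniform(0.5, 2.0, H))  # a toy state matrix that is contractive after discretization
B = rng.normal(size=(H, 1))
C = rng.normal(size=(1, H))
A_bar, B_bar = discretize_bilinear(A, B, delta=0.1)
x = rng.normal(size=(L, 1))

# Materialize the kernel K_k = C A_bar^k B_bar for k = 0..L-1
K = np.array([(C @ np.linalg.matrix_power(A_bar, k) @ B_bar).item() for k in range(L)])

y_rnn = ssm_recurrent(A_bar, B_bar, C, x).ravel()  # sequential (RNN) view
y_cnn = conv_fft(x.ravel(), K)                     # convolution (CNN) view via FFT
print(np.allclose(y_rnn, y_cnn))                   # should print True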

So… RNN or CNN?

Alright, so we saw that we can use these SSMs both as a RNN or as a CNN. But how is that useful?

The let’s-not-think-too-much answer is quite simple, use the:

  • CNN version when you have the whole input sequence available: You usually do have it available at train time. This means that both the forward and backward passes are highly parallelizable one-step operations, in contrast to training RNNs, which is known to be very inefficient.

  • RNN version when the whole input sequence is NOT available: mainly when you are using your model in an auto-regressive way or it is ingesting a stream of data. You usually use this version at inference time. At each time-step you just get $y_k, h_k$ from $h_{k-1}$ and $x_k$ by applying the recursion in Equation 2 (a couple of matrix multiplications).

17 Mainly because you cannot parallelize the processing of a sequence: you need to iterate one-element-at-a-time, compute the output, store the gradients, and then do the backpropagation through time

18 E.g: performing next-token prediction.

19 E.g: real-time audio transcription.

We want the CNN version to train, and the RNN version to infer.
Important 6: COMPLEXITY ANALYSIS

Notice that these are the same characteristics that make the transformer architecture so powerful. However, standard transformers require:

  • Quadratic memory at train-time (wrt input length)
  • Linear memory at inference time, which is prohibitive for long sequences.

What makes SSMs so attractive is that they work with:

  • Linear memory at train time.
  • Constant memory at inference time!

21 This is the most important takeaway of all this. More on this in the paper and in my post on sequence modelling complexity.

20 Besides explicitly implementing attention layers

SSM follow-ups

This is all very nice, but (quite unsurprisingly) the vanilla implementation performs very poorly in practice: it struggles to model non-linear dynamics. In this section we’ll go through the main ideas from the papers which introduce key improvements: HiPPO, S4, DSS, S6 (Mamba).

22 Fancy way of saying that the change in the output is not proportional to a change in the input

23 They are sorted chronologically (as they build on top of each other)

HiPPO

Ok, let’s forget for a moment about SSMs (trust me on this one, I know what I’m doing). Let’s try to solve the following problem instead: How can I compress everything I’ve seen so far into a fixed representation?

24 I find the intersection of data compression and machine learning interesting and I explored it a bit in another post in which I ramble around information theory.

25 High-order Polynomial Projection Operator: HiPPO: Recurrent Memory with Optimal Polynomial Projections paper (August 2020)

HiPPO introduces a general framework for the online compression of continuous signals (and discrete time series) by projecting them onto polynomial bases. This framework can easily be integrated into RNNs (such as SSMs). In addition, it is a generalization of previous models such as LSTMs and GRUs, improving on them and achieving SOTA on permuted MNIST (a popular continual learning benchmark).

For now, all you need to know is that initializing the SSM matrix $A$ with “the HiPPO matrix” makes the model a more effective learner:

26 It is observed that simply modifying an SSM from a random matrix A to HiPPO improved its performance on the sequential MNIST classification benchmark from 60% to 98%.

HiPPO matrix: $$A_{nk} = \begin{cases} \sqrt{(2n+1)(2k+1)} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$

In TLDR terms: this matrix makes the model’s hidden state $h(t)$ a representation of the whole seen input $x(t)$. This means that the hidden state $h(T) \in \mathbb{R}^H$ contains all the required information to (approximately) reconstruct $x(t)\ \forall t \in [0, T]$. Thus, we are effectively compressing the whole input sequence into a single point.

Problem

HiPPO formulates the problem of online function approximation as follows. Given a continuous function:

$$f(t): \mathbb{R} \to \mathbb{R}$$

The goal is to find a fixed representation (vector):

$$c(t) \in \mathbb{R}^N$$

Such that $c(T)$ captures the history of $f(t)\ \forall t \in [0, T]$.

But how? The idea is to project the seen function $f$ onto an N-dimensional function space spanned by some basis functions and store the coefficients of the projection $c$. For this, we need to specify a measure and a basis.

Measure

A measure (or weight / error function) μ(t) that tells how much we “care” about every time in the past. In the paper, they focus on these two cases:

  • Translated Legendre Measure (LegT): We only “care” about recent events, so we focus on a sliding window of length θ from now past-wise (if that word exists).

  • Scaled Legendre Measure (LegS): We “care” equally about everything we’ve seen so far.

The black line is $f(t)$. The red box represents the measure’s weight at $t_0$. The blue box represents the measure’s weight at $t_1$.

Basis

We need to choose a basis of the function space on which to project $f$. In that paper, they focus on orthogonal polynomials, in particular Legendre polynomials. However, one can project onto other bases, such as the Chebyshev basis, or the Fourier basis (recovering the Fourier Recurrent Unit).

As a refresher, Legendre polynomials satisfy

$$\int_{-1}^{1} P_m(x) P_n(x)\, dx = 0 \;\text{ if } n \neq m, \qquad \text{s.t. } P_n(1) = 1 \;\; \forall n$$

Visualization of the first six Legendre polynomials (they can easily be obtained by construction from P0(x)=1).

One can calculate the coefficients $a_n$ of the projection of a function $f(x): [-1, 1] \to \mathbb{R}$ onto Legendre polynomials by solving:

$$a_n = \frac{2n+1}{2} \int_{-1}^{1} f(x)\, P_n(x)\, dx$$

Then, we have that:

$$f(x) \approx \sum_{n=0}^{N} a_n P_n(x)$$
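As a quick sanity check of this static projection, here is a small NumPy sketch that computes the first N coefficients with a crude numerical integral and reconstructs the function (the example function and grid are arbitrary choices of mine):

import numpy as np
from numpy.polynomial import legendre

N = 10
xs = np.linspace(-1, 1, 2001)
dx = xs[1] - xs[0]
f = np.sin(5 * xs) * np.exp(-xs**2)  # some arbitrary function to approximate

coeffs = []
for n in range(N):
    Pn = legendre.Legendre.basis(n)(xs)                   # P_n evaluated on the grid
    coeffs.append((2 * n + 1) / 2 * np.sum(f * Pn) * dx)  # a_n via a crude Riemann sum

f_approx = legendre.legval(xs, np.array(coeffs))          # sum_n a_n P_n(x)
print(np.max(np.abs(f - f_approx)))                       # the error shrinks as N grows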

However, in this case, we are concerned with the approximation in a dynamic context, where the coefficients $c(t)$ change over time (as we “experience” $f(t)$). Check the paper for details, but it turns out that there exists a closed-form solution to this approximation problem, given by this linear differential equation:

$$\dot{c}(t) = A(t)\, c(t) + B(t)\, f(t)$$

Summary of the HiPPO framework under the LegS measure: The projection coefficients c evolve through time according to a linear dynamical system. g(t) represents the reconstructed function (linear combination of basis elements according to c(t)). The colored boxes represent the used measure (weight given through time).

Essentially, the coefficients’ rate of change depends linearly on their current value and on $f$. $A(t)$ and $B(t)$ depend on the chosen measure, and by following these dynamics (aka solving the ODE) one can find the coefficients $c(t)$ that optimally approximate $f(t)$ according to the chosen measure.

For instance, in the measures presented earlier, we obtain the following expressions:

  • HiPPO LegT Operator:

$$\frac{d}{dt}c(t) = -A\,c(t) + B\,f(t)$$

$$A_{nk} = \frac{1}{\theta}\begin{cases} (-1)^{n-k}(2n+1) & \text{if } n \geq k, \\ 2n+1 & \text{if } n < k, \end{cases} \qquad B_n = \frac{1}{\theta}(-1)^n(2n+1)$$

  • HiPPO LegS Operator:

$$\frac{d}{dt}c(t) = -\frac{1}{t}A\,c(t) + \frac{1}{t}B\,f(t)$$

$$A_{nk} = \begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k, \\ n+1 & \text{if } n = k, \\ 0 & \text{if } n < k, \end{cases} \qquad B_n = (2n+1)^{1/2}$$

Notice this is the state matrix we “pick” as a starting point for our model. Somehow “helping” it to learn the projection of the seen part of $f(t)$ onto Legendre polynomials with the scaled measure. 🥳

If curious, this is how the matrix looks for N=8.

Code
# Please don't judge the code, it's all ChatGPT
import numpy as np
import matplotlib.pyplot as plt


def generate_matrix(N):
    A = np.zeros((N, N))  # Initialize a N x N matrix with zeros

    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = np.sqrt((2*n + 1)*(2*k + 1))
            elif n == k:
                A[n, k] = n + 1
            else:
                A[n, k] = 0  # This line is technically not needed as the matrix is initialized with zeros

    return A

# For example, generate an 8x8 matrix
N = 8
matrix = generate_matrix(N)

plt.figure(figsize=(9, 7))
plt.imshow(matrix, cmap='viridis', aspect='equal')
plt.colorbar()
plt.title("HiPPO matrix")

# Annotate the heatmap with matrix values
for i in range(N):
    for j in range(N):
        text = plt.text(j, i, f"{matrix[i, j]:.2f}",
                       ha="center", va="center", color="w")

plt.show()

S4

S4 takes a basic SSM model and adds two key contributions:

27 Structured State Space for Sequence Modelling paper (Oct 2021). Note: Structured, because it imposes structure on the A matrix.

  • It uses the HiPPO matrix as the “state matrix” $A$, which they freeze (i.e. it is not learnable).

  • It implements a very efficient way to compute the convolution operation. In summary: they do not materialize the full kernel in memory, but use a generating function to obtain the needed value at each moment.

28 The only learnable params are $B, C, \Delta, D$.

Caution 8: This is SISO!

The SSM layer implementation only considers 1-dimensional input-output sequences (i.e. $x_t, y_t \in \mathbb{R}$). This is known as SISO (Single-Input Single-Output). Notice that in this case $\bar{K} \in \mathbb{R}^{1 \times L}$ is a vector.

To manage multivariate inputs and outputs, they stack N such layers (one for each dimension). They couple them together with nonlinear mixing layers (to break the independence assumption).

In the S5 paper, they generalize the S4 layer to be multivariate: not SISO anymore but MIMO (Multi-Input Multi-Output).

30 Simplified Structured State Space for Sequence Modelling (March 2023)

29 Resulting in N stacked SSMs with different parameters.

To make it apparent that dimensions of input and output pairs are independent, I represent them as “batched”. In this example: input_dim = output_dim = 3, and hidden_dim = 4. Thus, we get 3 kernels of dimension 1xL, which we apply in parallel.

Remember that 1-dimensional SSMs get defined by the matrices $(A, B, C)$, where $A \in \mathbb{R}^{N \times N}$ and $B, C \in \mathbb{R}^{N \times 1}$.

The following is an embarrassing over-simplification of the logic they follow to optimize the computation of the kernel:

  1. They notice there exists a special case of SSM matrix structure from which computing a truncated generating function of the kernel $\bar{K}$ becomes very fast: in particular, if the SSM has a Diagonal Plus Low-Rank (DPLR) decomposition in complex space. This means that the SSM needs to have the following structure:

$$\left(\Lambda - P Q^*, \; B, \; C\right)$$

  2. They notice that HiPPO’s matrix isn’t DPLR, but Normal Plus Low-Rank (NPLR).

  3. Normal matrices are unitarily diagonalizable:

$$A = V \Lambda V^* - P Q^T = V \left(\Lambda - (V^* P)(V^* Q)^*\right) V^*$$

  • Where:
    • $V \in \mathbb{C}^{N \times N}$ is unitary
    • $\Lambda \in \mathbb{C}^{N \times N}$ is diagonal
    • $P, Q \in \mathbb{R}^{N \times r}$ are the low-rank correction factors.

  4. They notice that NPLR matrices are thus equivalent to DPLR matrices from the perspective of SSM models.

  5. This means we can do the decomposition in complex space and get the generating function of the kernel, so we don’t need to fully materialize it in memory 🥳

33 Actually in this case the rank is 1 because they are vectors.

32 You don’t actually pre-compute the kernel, but a generating function making use of the z-transform

31 I will not open this melon 🍉 because it is very math heavy, but you can find an explanation here.

  • Standard MNIST classification (treating each image as a sequence of 784 pixels): as reported, an S4 model with H=256 and 4 layers gets to around 99% accuracy right away.

  • MNIST next-pixel prediction. Like next-token prediction in LMs but with pixel values: given an initial string (prompt), the model should complete the sequence. A model with H=512 and 6 layers (~4 million params) achieved SOTA according to PapersWithCode. Here is an example of a prediction (sample) and the ground truth (the model was fed the white part, the first 300 pixels, and inferred the rest):

  • QuickDraw next-pixel prediction. Same as before but with images representing more things than just digits: a model with H=256 and 4 layers generated “relatively” coherent completions (the prompt was the first 500 pixels).

  • Spoken digit classification. Same as MNIST classification but the input is an audio wave of people reading digits from 0 to 9: the model achieved 97% accuracy from the raw soundwave.

  • Spoken digit generation. The model gets prompted with the beginning of the audio wave and has to continue it: the cherry-picked examples sound very good (even the tone of voice is matched; check some examples here). H=512. The dataset has around 3k examples of around 6400 steps, at an 8kHz sampling rate, discretized into 256 classes with μ-law encoding:

orange=prompted part, blue=generated (or real) part.
  • Apparently it also performs well at movie-clip classification: ViS4mer paper. [PERSONAL NOTE]: I like the idea of combining the power of a transformer for fixed-size data (a frame) with the power of S4 for analyzing time series (a sequence of frames).

ViS4mer architecture

DSS

DSS shows that one can match S4’s level of performance by using a diagonal state matrix (i.e. dropping the low-rank correction). This massively simplifies the computation and makes its implementation much more straightforward.

34 Diagonal State Spaces paper (March 2022)

35 This makes sense if you have read the explanation of S4, otherwise don’t worry too much about it.

Again, the model is very sensitive to parameter initialization. They initialize it using the eigenvalues of the normal part of the NPLR form of the HiPPO matrix (selecting the ones that ensure a stable system).

DSS convolution. Same as S4, but using diagonal A matrices.

S6 (Mamba)

Enough chit-chat! Let’s address the 🐘 in the room: our model dynamics are constant. Look again at Equations 1 and 2: $A, B, C, \Delta$ (and thus $\bar{A}, \bar{B}$) remain fixed through the sequence! This property is known as Linear Time Invariance (LTI), and it is highly limiting in terms of modelling power. However, if we break LTI (by making these matrices variable), we lose the possibility to process the input time series in one step (we lose the CNN aspect of the SSM, as the “kernel” is not fixed any more 📉).

36 How things change through time.

37 Mamba: Linear-Time Sequence Modeling with Selective State Spaces paper (Jan 2024)

Mamba does exactly that, adding two contributions: selectivity (making the $B, C, \Delta$ matrices depend on the input $x_t$) and, as its consequence, the “scan operation” (a convolution substitute).

Combined together, they become the selective scan algorithm.

38 S6 because this is S4 + Selective Scan

Make of this whatever you want:

SSSSSS like the snakes

Selectivity

Mamba breaks the SSM’s LTI property by allowing its parameters $B, C, \Delta$ to depend on the current input. We now have:

39 They keep A fixed, they use the same one as in DSS.

$$h_t = \bar{A}_t\,h_{t-1} + \bar{B}_t\,x_t \qquad y_t = C_t\,h_t \tag{3}$$

Where the matrices $\bar{A}_t, \bar{B}_t$ of the discretized version are computed as explained before. But now, we have that:

40 Remember that $\bar{C} = C$ (as seen in Equation 2).

$$B_t = \text{Linear}_B(x_t) \qquad C_t = \text{Linear}_C(x_t) \qquad \Delta_t = \text{Linear}_\Delta(x_t)$$

In the paper, they show the differences between S4 and S6 like so:

They refer to this property as “selectivity”: now the model can be more selective about what gets included in the state. Some inputs $x_t$ might be more relevant than others for the given task; by making $B$ dependent on $x$, we can more easily discard non-relevant stuff. This way we ensure that not all inputs are treated equally: VIP ones have a bigger influence.

41 For instance, in audio transcription it might not be very important to remember that 5 minutes ago there was a long pause. (I talk a lot about transcription because it’s currently paying my bills).

42 Having Δ dependent on x has a similar effect, it is like choosing the step size based on how “interesting” an observation is.
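To make Equation 3 concrete, here is a toy, single-channel sketch of one selective recurrence step. The shapes, the softplus on $\Delta$, and the simplified discretization of $B$ are my own illustrative choices, not the paper’s exact parameterization (which batches this over many channels):

import numpy as np

rng = np.random.default_rng(0)
D, H = 8, 16                        # token dim and hidden dim (made-up sizes)
A = -np.arange(1.0, H + 1.0)        # fixed diagonal (stable) state matrix, as in DSS
W_B = rng.normal(size=(H, D)) / D   # Linear_B
W_C = rng.normal(size=(H, D)) / D   # Linear_C
w_d = rng.normal(size=D) / D        # Linear_Delta (produces a scalar step size)

def selective_step(h_prev, x_t, u_t):
    """One step of Eq. (3) for a single SISO channel with scalar input u_t,
    where B_t, C_t, Delta_t are 'selected' from the full token x_t."""
    B_t = W_B @ x_t
    C_t = W_C @ x_t
    delta_t = np.log1p(np.exp(w_d @ x_t))  # softplus keeps the step size positive
    A_bar_t = np.exp(delta_t * A)          # ZOH discretization is elementwise for a diagonal A
    B_bar_t = delta_t * B_t                # simple (Euler-like) discretization of B
    h_t = A_bar_t * h_prev + B_bar_t * u_t
    y_t = C_t @ h_t
    return h_t, y_t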

Why can’t we use the convolution trick any more? Very simple: the convolution operation is designed for a kernel which is always the same, whereas here the coefficients of the kernel change at each step depending on the input!

For instance, when processing $y_1$ we have that the first element of the kernel is $k_1 = C_1^T \bar{B}_1$:

Whereas on the next convolution step, we now have that $k_1 = C_2^T \bar{B}_2$ 😠

In conclusion, we cannot use the convolution operation, so we need to go back to the recurrent version of the model.

The scan operation

Alright, let’s forget about the CNN then. Let’s take a step back and look into the RNN. A naive way of implementing the computation (when the complete sequence is provided) would be to do something like this:

43 I’ll focus on the computation of h[t] since getting y[t] from h[t] is a simple matrix multiplication.

T = len(x)

h[0] = B_bar[0] * x[0]  # This is a vector (h[-1] = 0, so the first step has no recurrent term)
for t in range(1, T):
    h[t] = A_bar[t] * h[t-1] + B_bar[t] * x[t]

An (overcomplicated) representation of this computation would be this:

44 Please bear with me, it’ll make understanding the scan operation much simpler.

First we compute B_bar[t] * x[t] for all t. This is a tensor multiplication without previous time-step dependencies, which can easily be parallelized. Then, for the iterative part, at every step t we take the output of the previous step h[t-1], multiply it by A_bar[t] and add B_bar[t] * x[t], which gives us h[t]. Notice that we need 8 steps to compute the output of a sequence of length 8.

This is a very efficient way to compute it in terms of the number of operations. However, remember that the hardware (GPUs, TPUs) we use to train this type of model is optimized for high parallelization. And, ultimately, what we care about is the time to complete the computation. Considering this, is there a way to compute it faster? Yes, there is! Taking inspiration from the parallel scan operation.

45 This thing has more names than a Spanish telenovela character. It is referred to as: cumulative sum, or prefix sum, or inclusive scan, or just scan.

46 It is not the most work-efficient one, so it is not what Albert and Tri do, but I think it is useful to get the main idea behind this.

47 Notice that now h is a matrix of shape (log_2(T) + 1) x T. I add a new dimension to h to store intermediate steps. In the end, we will use the last row.

There exist different implementations of this operation with varying properties; check Wikipedia. Here I’m going to focus on the simplest one to grasp, just to get the gist of it. The idea is the following:

T = len(x)

h[0, :] = B_bar[:] * x[:]  # Broadcasted for all t
for i in range( log_2(T) ):
    for t in range(T) DO IN PARALLEL:
        if t < 2^i:
            h[i+1, t] = h[i, t]  # Nothing further back to merge with: already covers [0, t]
        else:
            A_acc = A_bar[t] * A_bar[t - 1] * ... * A_bar[t - 2^i + 1]
            h[i+1, t] = h[i, t] + A_acc * h[i, t - 2^i]

Using the same diagram as before, we can now represent the computations as follows:

48 Not so stupid after all to represent it the way I did, is it?

The general idea is to compute the sequence in parts and iteratively combine them thanks to the associative property. Thus, we now only need log_2(T) steps to compute the output! 🥳 (if your machine has at least T processors).

It might be useful to see the “path” the computation follows to get a particular output. I acknowledge there is an interesting mess of lines:

As you can see, it depends on all previous inputs and the powers of $\bar{A}$ are correctly assigned. Notice I simplified the problem by assuming a constant $\bar{A}$ for a clearer explanation. The idea is the same with input-dependent $\bar{A}_t$.
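Here is a runnable NumPy version of this idea for the scalar/diagonal case, checked against the sequential loop. Each sweep of the while-loop is a single vectorized (i.e. parallelizable) operation, and, as noted above, this simple variant is not the most work-efficient formulation (and not exactly what Albert and Tri implement):

import numpy as np

def parallel_scan(a, b):
    """Compute h[t] = a[t] * h[t-1] + b[t] (with h[-1] = 0) for all t,
    using a Hillis-Steele-style scan: each while-iteration is one parallel sweep."""
    A_acc, H = a.copy(), b.copy()
    offset = 1
    while offset < len(a):
        # Merge each position with the partial result `offset` steps behind it
        H[offset:] = H[offset:] + A_acc[offset:] * H[:-offset]
        A_acc[offset:] = A_acc[offset:] * A_acc[:-offset]
        offset *= 2
    return H

# Check against the sequential recurrence
rng = np.random.default_rng(0)
T = 64
a = rng.uniform(0.5, 0.99, T)   # stand-ins for A_bar[t] (diagonal / scalar case)
b = rng.normal(size=T)          # stand-ins for B_bar[t] * x[t]
h = np.zeros(T)
h[0] = b[0]
for t in range(1, T):
    h[t] = a[t] * h[t - 1] + b[t]
print(np.allclose(parallel_scan(a, b), h))  # True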

On top of that, they are very conscious of how to efficiently perform these operations in a GPU.

49 After all, these guys are the ones who came up with FlashAttention (May 2022)

Hardware-aware implementation of the scan operation. In this case, we have an input_dim = output_dim = 5, and a hidden_dim = 4. The idea is that they only materialize the hidden states in the most efficient level of memory (GPU SRAM). This avoids the slow operation of copying these values from SRAM into the HBM (this slow IO transfer speed is one of the main bottlenecks of current GPUs).

50 More about fused CUDA kernels here.

Last remark: When training, they do not store the intermediate states in the forward pass. Instead, they recompute them again in the backward pass, which ends up being more efficient. This idea is also a common trick to reduce the memory footprint at train time.

51 This has many names, such as gradient checkpointing or rematerialization. If interested, check out Training Deep Nets with Sublinear Memory Cost, or this blog post.

The optimization in Mamba is somewhat in line with FlashAttention (May 2022), proposed by Tri Dao et al. The idea is to substitute the PyTorch operations (or those of whichever deep learning framework is being used) with a custom CUDA kernel (aka fused kernel) which combines and performs them more efficiently:

Common flow of operations done in PyTorch. Image from here.

Flow of operations in a memory-aware fused CUDA kernel. Image from here.

For instance, in the case of FlashAttention, they substitute the omnipresent dot-product attention operations of the transformer with a single fused operation:

$$\text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

This (obviously) results in a faster and more memory-efficient attention method. The approach not only combines the operations but also makes better use of the GPU memory hierarchy. For instance, an A100 GPU has:

  • 40GB (or 80 GB) of HBM (High-Bandwidth-Memory): Large but slow: This is implemented by stacking multiple DRAM (Dynamic Random Access Memory) dies allowing high parallel data transfers.

  • ~0.2MB x 108 processors of SRAM (Static Random Access Memory): Small but fast: The speed advantage of SRAM over DRAM comes from SRAM’s ability to hold a given data bit in a static state (on or off) as long as power is supplied. Moreover, it is more reliable than DRAM. DRAM must refresh its stored data bits many times per second to maintain the integrity of the data stored (making operations slower). However, SRAM presents a much higher cost-per-bit and requires more physical space on the chip. Thus, it is usually reserved for operations where speed and reliability are critical (such as caches in CPUs or GPUs).

52 Fancy way of saying: a function that runs on the GPU, written in CUDA.

The mamba block

Similar to a transformer block, they provide a “Mamba block”, which can be stacked to build larger models. It basically adds feature expansion, a CNN and a skip connection around the SSM layer like so:

If you check the official implementation, specifically the Mamba block, you’ll see these important parameters: d_model (input and output dimension), d_state (SSM hidden state dim), d_conv (conv-1d kernel size), and expand (initial projection expansion factor). I referenced them in the diagram.
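For orientation only, here is a very rough NumPy sketch of the data flow I believe happens inside one block, based on my reading of the paper and the official code. Shapes, names and the placeholder selective_ssm are mine, not the real implementation:

import numpy as np

def silu(x):
    return x / (1 + np.exp(-x))

def causal_conv1d(u, kernel):
    """Depthwise causal 1-D convolution. u: (L, C), kernel: (d_conv, C)."""
    d_conv = kernel.shape[0]
    u_pad = np.pad(u, ((d_conv - 1, 0), (0, 0)))  # left-pad so each step only sees the past
    return np.stack([(u_pad[t:t + d_conv] * kernel).sum(axis=0) for t in range(u.shape[0])])

def mamba_block(x, W_in, conv_kernel, W_out, selective_ssm):
    """x: (L, d_model); W_in: (d_model, 2 * expand * d_model);
    conv_kernel: (d_conv, expand * d_model); W_out: (expand * d_model, d_model).
    `selective_ssm` stands in for the S6 layer sketched earlier (it hides d_state)."""
    u, z = np.split(x @ W_in, 2, axis=-1)    # feature expansion, then split into two branches
    u = silu(causal_conv1d(u, conv_kernel))  # short causal conv + nonlinearity on the SSM branch
    y = selective_ssm(u)                     # the selective SSM described above
    y = y * silu(z)                          # gate with the branch that skips the SSM
    return y @ W_out                         # project back down to d_model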

And that’s pretty much it 🙃

Epilogue

I’d just like to add that I found the line of research by Albert and Tri extremely inspiring. It’s super cool to see such talented people thinking of new approaches to solve time-series problems (instead of micro-optimizing transformers). I like seeing the results of combining ideas from so many fields: control theory from SSMs, function projections from 🦛, parallel computing and hardware optimization from 🐍

In a next post I’ll explore the relationship between the original transformer, the linear transformer, Mamba, and maybe some other RNN model. I’ll pay special attention to similarities in architecture and time & memory complexities at train/eval time.

53 As I was writing this (May 2024), Mamba2 came out and they do a lot of this work haha, so I guess I’ll be summarizing that :)