Linear Transformers, Mamba2, and many ramblings

I go through architectures used in sequence modelling: FFN, CNN, RNN, SSM, and Transformers, along with many efficiency-optimization attempts. I provide an intuitive understanding of how they work and analyze their strengths and weaknesses, all while paying special attention (pun intended) to their computational and memory complexities.
Beyond the transformer
Author

Oleguer Canal

Published

January 15, 2025

Apparently, mosquitoes 🦟 —when flying at night— rely on a source of light for navigation. Natural sources of light tend to be very far away 🌘. Thus, mosquitoes can follow a straight line by simply keeping a constant angle with them. However, artificial lights 💡 break this system: if one keeps a steady angle with a close object, one ends up going in circles…

That being said, today I want to explore various concepts around “sequence modeling”. Let’s start by formalizing it a bit. We’ll consider two sequences:

$$X = (x_0, x_1, x_2, \dots) \qquad Y = (y_0, y_1, y_2, \dots)$$

Where $x_t, y_t \in \mathbb{R}^f \ \forall t$, and the order within $\{(x_0, y_0), (x_1, y_1), (x_2, y_2), \dots\}$ matters (each element might depend on the previous ones). This post compares different ways of computing the mapping:

$$X \to Y$$

I represent 1D tensors as rectangles. We can stack them to create 2D tensors X, Y.

I’ll try to provide an intuitive understanding of how common models work, how they relate to one another¹, and their computational/algorithmic complexity.

1 Mainly through the lenses of RNNs, because of their easy interpretability.

Focus

  • I’ll be mainly focusing on the computational bottleneck part of each model: the component performing attention, recursion, selectivity…
  • I’ll mostly ignore feed-forward, normalization and other embarrassingly-parallelizable steps.

Structure

Whenever it is interesting, I’ll split model formulation between:

  • Train-time: Where I assume the whole input sequence is available. Here we are usually interested in processing the whole sequence in a single step (i.e. not recurrently), to leverage the fact that the complete sequence is available.

  • Inference-time: I’ll focus on the case where the complete sequence is not available. For instance: when processing a stream of data (e.g. real-time audio transcription), or when running the model in an autoregressive manner (e.g. next-token prediction).

Notation

  • $T$: Length of the sequence at train time.
  • $t$: Length of the sequence at inference time, $t \in [0..T]$.
  • $f$: Feature dimension, usually both the input/output dimension and the hidden state size (if applicable, otherwise specified). It is assumed that $f \ll T$.

Inspiration

Narrative

There are so many details to comment that the big picture of the post might get a bit lost. Here is what I was going for:

Let’s tackle the problem $X \to Y$, where $X, Y$ are sequences.

  • What about $y_t = f_\omega(x_t) \ \forall t$? (Feed Forward)
    • Too simplistic: we are not using the fact that the data is ordered in a meaningful way.
  • What about $y_t = f_\omega(x_t, x_{t-1}, \dots, x_{t-k}) \ \forall t$? (1D CNN)
    • We’d need to concatenate many of those layers for context to flow from beginning to end.
  • What about $y_t, S_t = f_\omega(x_t, S_{t-1}) \ \forall t$? (RNN)
    • In general this is not parallelizable (super slow training), the network forgets information, and we have vanishing gradient problems.
  • What about making $S_t$ the concatenation of all previously seen $x_t$, i.e. $S_t = \{x_t, x_{t-1}, \dots, x_0\}$? (Transformer)
    • This works great but consumes a lot of memory and FLOPs.
  • What about a more efficient approximation of the previous one?
    • Nice ideas, but not as good as the softmax transformer.
  • Let’s go back to the RNN idea: isn’t there a way to train it more efficiently? (Mamba1)
    • Yes, if we add some structure.
  • Can’t we run it faster? (Mamba2)
    • Yes, if we add even more structure. Wait, we re-discovered a more general version of the linear transformer coming from the SSM branch!

In the following table you can see the memory and FLOP costs of the main models we’ll cover:

TLDR of this post (I still recommend reading it tho haha). I removed the $O(\cdot)$ notation for readability. I marked in bold the computationally-problematic terms (mainly those of the softmax transformers). We assume $T \gg f$. “One-step-computable” refers to whether, given $x$, we can obtain $y$ directly without iterating through the sequence (assuming enough parallel compute power is available).² ³ ⁴

| Algorithm | Train memory | Train FLOPs | Inference memory | Inference FLOPs | One-step computable | Global context |
|---|---|---|---|---|---|---|
| FF | $Tf$ | $Tf^2$ | $f^2$ | $f^2$ | 🟢 | 🔴 |
| 1D CNN | $Tf$ | $Tf^2$ | $f^2$ | $f^2$ | 🟢 | 🔴 |
| Standard RNN | $Tf$ | $Tf^2$ | $f^2$ | $f^2$ | 🔴 | 🟡 |
| Naive Softmax Transformer | **$T^2$** | **$T^2 f$** | $tf$ | $tf$ | 🟢 | 🟢 |
| Flash-Attn Softmax Transformer | $Tf$ | **$T^2 f$** | $tf$ | $tf$ | 🟢 | 🟢 |
| Mamba1 | $Tf^2$ | $Tf^2$ | $f^2$ | $f^2$ | 🟡 | 🟡 |
| Linear Transformer & SSD (Mamba2) | $Tf$ | $Tf^2$ | $f^2$ | $f^2$ | 🟢 | 🟡 |

2 But in practice much faster, thanks to fused kernel.

3 Because information gets compressed

4 Heavily optimized with strong structure and the scan operation but still not one-step.

But why is sequence modelling a challenging problem?

5 Enterprise code repositories can easily be in the order of 100k lines of code, a single second of audio has 44k datapoints (if recorded at a standard 44kHz), the human genome has 3.1 billion base pairs.

Bad combo… Still, there are several smart methods to be able to overcome these limitations. Put on your dancing shoes 🩰 because the show is about to start!

Feed-forward (FF)

Alright let’s get this over with! What’s the easiest thing we can do? 🤔.

Given $x, y \in \mathbb{R}^{T \times f}$, we could map $x_t \to y_t$ by just doing⁶:

6 I include this basic model to establish a computational lower bound and provide context for more sophisticated approaches.

$$y_t = f_\omega(x_t) \quad \forall t$$

Each $y_t$ is computed only from its corresponding $x_t$.

At its simplest, $f_\omega$ can be a linear projection⁷: $f_\omega(x) := Ax$, where $A \in \mathbb{R}^{f \times f}$. In this case, the “computational bill” at train-time becomes $O(Tf^2)$ FLOPs, since we have to perform $T$ matrix-vector multiplications of size $f \times f \to f$. In terms of memory, allocating the input and output is the main bottleneck, thus we have a cost of $O(Tf)$⁸.

7 We can also add some non-linearity and compose multiple functions to increase modelling power.

8 In language modelling this is known as a 2-gram model, in which case Y=X[1:]+<eos> (next-token prediction)

This has great computational appeal: it is extremely parallelizable at train-time, and inference can be done in constant time and memory. However, its limited modeling capabilities make it insufficient for most real-world applications: mainly, it doesn’t leverage the sequential nature of the data! There is no information flow between elements of the sequence.
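To make the shapes and costs concrete, here is a minimal PyTorch sketch of this per-timestep linear map (the variable names are mine and purely illustrative):

```python
import torch

T, f = 1024, 64                      # sequence length, feature dimension
X = torch.randn(T, f)                # input sequence
A = torch.randn(f, f)                # the (learnable) f x f projection

# y_t = A x_t for every t: one batched matmul, O(T f^2) FLOPs, O(T f) memory.
Y = X @ A.T                          # (T, f) -> (T, f), no interaction across time

# Inference on a stream is O(f^2) time and memory per step:
x_t = torch.randn(f)
y_t = A @ x_t
```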

Still, layers like this play an important role in more complex models: they non-linearly combine internal features, normalize features, and store model knowledge⁹.

9 I will not spend more time on this, as it is the part of the models I said I wouldn’t bother with in Tip 1.

Convolutional Neural Network (1D-CNN)

Ok, so how can we do better? 🤷

Instead of mapping each input of the sequence element-wise to an output, we could map a sliding window of $k$ inputs to each output:

$$y_t = f_\omega(x_t, x_{t-1}, \dots, x_{t-k})$$

Each $y_t$ is computed only from a fixed window of $x_t, \dots, x_{t-k}$.

This keeps most of the computational appeal from the previous idea while also allowing us to locally transfer information along x.


Leo is right.

Interestingly, we can sequentially compose multiple layers of this type to propagate information over longer time-spans:

$$y = f_{\omega_L} \circ \dots \circ f_{\omega_2} \circ f_{\omega_1}(x)$$

With each successive convolution, local features get cumulatively aggregated, capturing information from larger and more general aspects of the input, and finally yielding a global understanding of it. In particular, if we have an input of length $T$, we’d need at least $L = \frac{T}{k-1}$ layers for the whole input to have an influence on the whole output, i.e. for $x_0$ to be considered in $y_T$.

Representation of how long it takes for $x_0$ to have an influence on $y_T$.
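Here is a small sketch of such a stack of causal 1D convolutions (left-padding is one common way of enforcing causality; the exact architecture is illustrative):

```python
import torch
import torch.nn as nn

T, f, k, L = 1024, 64, 3, 4          # sequence length, channels, kernel size, depth

# Causal 1D convolutions: left-pad by k-1 so y_t only sees x_t ... x_{t-k+1}.
layers = []
for _ in range(L):
    layers += [nn.ConstantPad1d((k - 1, 0), 0.0), nn.Conv1d(f, f, k), nn.ReLU()]
cnn = nn.Sequential(*layers)

X = torch.randn(1, f, T)             # Conv1d expects (batch, channels, time)
Y = cnn(X)                           # (1, f, T)

# Each layer widens the receptive field by k-1: after L layers, y_t sees roughly
# L * (k - 1) + 1 past inputs, so covering the whole input needs ~T / (k - 1) layers.
```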

This is very intuitive in the 2D CNNs used for vision: the first layers extract the position of very simple features, such as edges. As the input advances through the network, these features get combined into more and more complex patterns. In the last layers, features become recognizable common shapes such as eyes, wheels, roofs… This idea was key in early/small computer vision tasks (e.g. AlexNet) and helped propel the ML field. Now, transformers (e.g. the Vision Transformer) perform this much more effectively by directly allowing all-to-all input interactions at each layer. This means all inputs have influence on all outputs at each time-step (let’s forget about causality for now), removing the need for very deep networks. I provide more intuitive understanding of this in my post about deep dream.

This approach presents other disadvantages, such as time/position invariance (fixed kernel parameters $\omega$ across the whole sequence, regardless of the input values) and vanishing gradients for very deep networks (partly solved by residual connections).

Overall, not being designed for long-context information transmission makes them a bad candidate for the studied problem. As with FFNs though, some modern sequence models include CNNs to locally combine features (see the Mamba models) or to compress the temporal dimension of the sequence (e.g. the first layers of the Whisper speech-to-text model).

Recurrent Neural Network (RNN)

Hmmm, so how can we more effectively transfer information through time? 💭

We can store some internal state containing the relevant information that needs to be transmitted across time: $S_t$¹⁰. We would apply our model sequentially, like so:

10 $S$ as in “state” at time $t$. Also known as $h_t$ for “hidden state”.

$$S_t, y_t = f_\omega(S_{t-1}, x_t)$$

Each $y_t$ is computed from $S_{t-1}$ and $x_t$.

This is a very generic formulation and there exist multiple ways of implementing it (as we’ll see later). In Tip 3 I summarize a couple of influential, now classic, RNNs.
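Before that, here is a minimal sketch of the generic recurrence above (a plain Elman-style cell, not the GRU/LSTM update rules summarized next):

```python
import torch

T, f = 1024, 64
X = torch.randn(T, f)

# Parameters of a plain (Elman-style) RNN cell: S_t = tanh(W_s S_{t-1} + W_x x_t)
W_s, W_x, W_y = (torch.randn(f, f) * 0.1 for _ in range(3))

S = torch.zeros(f)                   # fixed-size state: O(f) memory at inference
ys = []
for x_t in X:                        # inherently sequential: S_t depends on S_{t-1}
    S = torch.tanh(W_s @ S + W_x @ x_t)
    ys.append(W_y @ S)               # y_t is read out from the state
Y = torch.stack(ys)                  # (T, f)
```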

Here I present two of the most relevant functional forms for $f_\omega$, now rarely used because of the limitations I explain later on in the post.

GRU

GRUs use two gates (an update gate and a reset gate) to regulate the flow of information. The update gate balances between carrying forward previous hidden states and incorporating new information, while the reset gate decides how much of the past context to forget.

LSTM

LSTMs utilize three gates (input, forget, and output gates) along with a dedicated cell state to effectively maintain and process long-range dependencies in sequential data. The gates control how much new information is added, how much old information is discarded, and how much of the current cell state is passed to the output.

Notice that LSTMs maintain two separate states, an internal cell state and a hidden state, to better isolate and preserve long-range information. $S_t$ in this case can be seen as the concatenation of both $c_t, h_t$. GRUs, on the other hand, combine these into a single hidden state and rely on fewer gates. This design makes GRUs simpler and faster to train, but LSTMs can sometimes capture longer dependencies more effectively thanks to the separate cell state.

In their classic form (constant $S_t$ size and non-parallelizable sequence training) they presented several drawbacks which made them obsolete for big problems:

  • Non-parallelizable sequence training becomes prohibitive for long sequences. The parallelization power of GPUs is lost if the computation of $S_t$ is blocked by the computation of $S_{t-1}$. One can still parallelize along the batch dimension, but weight updates are still too slow in comparison to the other methods.

  • The fixed size of the state $S_t$ might be too small to compress all relevant information of the sequence, resulting in forgetting problems.

  • Vanishing gradient problems, which arise from back-propagation through time. In the backward pass, for each step $t$ we compute the gradient as $g_t = g_{t+1} J_t$, where $J_t$ is the Jacobian matrix of step $t$. Since $J_t$ is usually contractive¹¹, its cumulative product decays or “vanishes” with sequence length, to the point of having near-zero effect after a few iterations. This results in the network struggling to learn dependencies from earlier inputs¹² (a tiny numeric illustration follows after these footnotes).

11 I.e. it has eigenvalues $|\lambda| < 1$.

12 Solutions include: LSTMs (incorporate a gating mechanism to allow for longer-range dependencies), ReLU activations (instead of sigmoid or tanh, which have very small derivatives far from 0), gradient clipping (to prevent gradients from being too small or too large), or layer normalization (helps stabilize gradients).
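The promised illustration: repeatedly applying a contractive Jacobian shrinks the gradient geometrically (the matrix here is made up just for the example):

```python
import torch

f = 64
torch.manual_seed(0)

# A "typical" step Jacobian, rescaled to be contractive (spectral norm 0.9 < 1).
J = torch.randn(f, f)
J = 0.9 * J / torch.linalg.matrix_norm(J, ord=2)

g = torch.ones(f)                          # gradient arriving from the last time step
for t in range(1, 101):                    # g_t = g_{t+1} J_t, applied repeatedly
    g = g @ J
    if t in (1, 10, 50, 100):
        # The norm shrinks at least as fast as 0.9^t: early inputs barely get a signal.
        print(t, torch.linalg.vector_norm(g).item())
```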

We’ll later see that we can re-work either the functional form of $f_\omega$ or the definition of $S_t$ and bypass those limitations.

Transformer Decoder

Are we still doing the rhetorical question thing? Yes. Ok, so if compressing information doesn’t work, what could we do instead?

We can simply consider the complete sequence for each guess! This solves all the performance issues, at higher memory and computation costs (obviously). This is how Transformers do it:

Train-time

Given the input $X \in \mathbb{R}^{T \times f}$¹³, we linearly project it into queries, keys and values¹⁴:

$$Q = \mathrm{Linear}_{\omega_q}(X) \in \mathbb{R}^{T \times f} \qquad K = \mathrm{Linear}_{\omega_k}(X) \in \mathbb{R}^{T \times f} \qquad V = \mathrm{Linear}_{\omega_v}(X) \in \mathbb{R}^{T \times f}$$

13 I’ll focus on self-attention and, for simplicity, I assume all projections are done into a space of $f$ dimensions.

14 I provide interpretability of those in my Attention Mechanism Zoo post

And we then apply dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{L \circ (QK^T)}{\sqrt{f}}\right)V$$

Where the softmax is applied row-wise, and $L \in \mathbb{R}^{T \times T}$ is a lower-triangular matrix used for causality masking. This is the bill of naively implementing this:

| Operation | Memory | FLOPs |
|---|---|---|
| $Q, K, V$ projections | $Tf$ | $Tf^2$ |
| Computing and allocating $QK^T$ | $T^2$ | $T^2 f$ |
| Masking & softmax | 1¹⁵ | $T^2$ |
| Values | $Tf$ | $T^2 f$ |
| Total | $O(T^2)$ | $O(T^2 f)$ |
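A minimal sketch of this naive implementation (single head, no batching, random weights; just to make the $T \times T$ bottleneck explicit):

```python
import torch

T, f = 1024, 64
X = torch.randn(T, f)
Wq, Wk, Wv = (torch.randn(f, f) for _ in range(3))

Q, K, V = X @ Wq, X @ Wk, X @ Wv                    # (T, f) each: O(T f^2) FLOPs
scores = (Q @ K.T) / f**0.5                         # (T, T): the O(T^2) memory culprit
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~mask, float("-inf"))   # causal masking (the L matrix)
attn = torch.softmax(scores, dim=-1)                # row-wise softmax
Y = attn @ V                                        # (T, f): O(T^2 f) FLOPs
```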

15 As in no extra space is needed

16 My research crush.

Luckily, Tri Dao¹⁶ & company introduced a way around it in 2022: Tip 4.

FlashAttention (May 2022) introduces two key ideas:

  • Less memory usage, because $QK^T$ is not materialized in memory.
  • Faster execution, because of the fused kernel.

The idea is to substitute the PyTorch operations (or those of whichever deep learning framework is being used) with a custom CUDA kernel¹⁷ (aka fused kernel) which combines and performs them more efficiently:

Common flow of operations done in PyTorch. Image from here.

Flow of operations in a memory-aware fused CUDA kernel. Image from here.

Consider: $\mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

This (obviously) results in a faster and more memory-efficient attention method. The approach not only fuses the operations but also makes better use of the GPU memory hierarchy. For instance, an A100 GPU has:

  • 40 GB (or 80 GB) of HBM (High-Bandwidth Memory): large but slow. It is implemented by stacking multiple DRAM (Dynamic Random Access Memory) dies, allowing highly parallel data transfers.

  • ~0.2 MB of SRAM (Static Random Access Memory) on each of its 108 processors: small but fast. The speed advantage of SRAM over DRAM comes from SRAM’s ability to hold a given data bit in a static state (on or off) as long as power is supplied. Moreover, it is more reliable than DRAM, which must refresh its stored data bits many times per second to maintain their integrity (making operations slower). However, SRAM has a much higher cost-per-bit and requires more physical space on the chip. Thus, it is usually reserved for uses where speed and reliability are critical (such as caches in CPUs or GPUs).

17 Fancy way of saying: function that runs on GPU written in CUDA

Inference-time

Imagine we have cached¹⁸ $K_{0:t-1}$, $V_{0:t-1}$. Then, at time $t$ with input $x_t$, we only need to compute:

18 I’ll focus on the case where we use KV-caching.

$$q_t = \mathrm{Linear}_{\omega_q}(x_t) \in \mathbb{R}^{1 \times f} \qquad k_t = \mathrm{Linear}_{\omega_k}(x_t) \in \mathbb{R}^{1 \times f} \qquad v_t = \mathrm{Linear}_{\omega_v}(x_t) \in \mathbb{R}^{1 \times f}$$

Then, using the cache:

$$K_{0:t} = \begin{bmatrix} K_{0:t-1} \\ k_t \end{bmatrix} \in \mathbb{R}^{t \times f}, \qquad V_{0:t} = \begin{bmatrix} V_{0:t-1} \\ v_t \end{bmatrix} \in \mathbb{R}^{t \times f}$$

We can then compute the new attention output vector:

$$\mathrm{Attention}_t(Q, K, V) = \mathrm{softmax}\!\left(\frac{q_t K_{0:t}^T}{\sqrt{f}}\right)V_{0:t}$$

1. We compute $q_t, k_t, v_t$ from $x_t$.
2. How is $q_t$ related to all previous keys? We compute its dot-product with each of them (using the cached keys $K_{0..t-1}$). We then apply a softmax for normalization and call the result the “attention vector”.
3. The final result is the weighted average of the values according to the attention vector obtained in step 2¹⁹ (using the cached values $V_{0..t-1}$).

19 If there was a strong relationship between a query-key pair, its associated value will have a strong influence on $y_t$.
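A sketch of these three steps with a growing KV-cache (single head, no batching; the projections are random placeholders):

```python
import torch

f = 64
Wq, Wk, Wv = (torch.randn(f, f) for _ in range(3))

K_cache = torch.zeros(0, f)                  # grows by one row per processed token
V_cache = torch.zeros(0, f)

def decode_step(x_t):
    """One autoregressive step with KV-caching: O(t f) FLOPs, O(t f) cache memory."""
    global K_cache, V_cache
    q_t, k_t, v_t = x_t @ Wq, x_t @ Wk, x_t @ Wv               # step 1: O(f^2)
    K_cache = torch.cat([K_cache, k_t[None]])                  # append k_t to the cache
    V_cache = torch.cat([V_cache, v_t[None]])                  # append v_t to the cache
    attn = torch.softmax((q_t @ K_cache.T) / f**0.5, dim=-1)   # step 2: attention vector (t,)
    return attn @ V_cache                                      # step 3: weighted average of values

for _ in range(5):
    y_t = decode_step(torch.randn(f))
```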

Computationally the cost comes down to these steps:

| Operation | Memory | FLOPs |
|---|---|---|
| $q, k, v$ projections | $f$ | $f^2$ |
| Computing and allocating $q_t K_{0:t}^T$ | $t$ | $tf$ |
| Softmax | 1 | $t$ |
| Values | $f$ | $tf$ |
| Total | $O(t)$ | $O(tf)$ |

Transformer Decoder through the RNN lenses

Notice we can see the KV-cache as the internal state $S_t$ of the model. The particular thing about this model is that the internal state grows linearly with the sequence: it takes $O(tf)$ memory.

Always having access to all past tokens is a key characteristic of transformers (both good and bad):

  • They don’t forget (in contrast to the fixed-size state of common RNNs).
  • Each layer has global context (in contrast to CNNs, whose layers have local context and which need to be very deep in order to extract global input features).
  • They have a growing memory footprint, making them inappropriate for very long sequences.

Linear Transformers

Uff, can’t we do some kind of approximation which is almost-as-good but at a much-lower computational cost? 😖

We can try! The choice of the softmax function as a non-linearity / normalization in the attention mechanism might initially seem a bit arbitrary (and maybe it was). However, as we will explore in this section, it has proven to be crucial for the performance of transformers and extremely challenging to improve upon.

There have been many attempts to make the standard transformer architecture more computationally efficient (see Tip 5²⁰).

20 This could be a post on itself, here I just go over some interesting ideas.

Lowering the computational complexity of the original transformer can yield many benefits: faster processing and the ability to process larger context windows, for instance. Despite the merit of the approaches we’ll review, the benefits obtained usually get undermined by the quality deterioration caused by the approximations made.

Still, it is worth understanding the efforts made; some of them are quite neat, and might inspire future methods:

Mapping of some relevant attempts at reducing transformer complexity. paper

We can roughly categorize these attempts into the following groups²¹:

  • Linformer leverages the empirical observation that the attention matrix is more-or-less low-rank (they look at its eigenvalue distribution). They use the Johnson–Lindenstrauss lemma to approximate the attention matrix. Fanciness aside, this lemma states that if we use a random projection matrix to project a set of points onto a lower dimension, the pairwise distances are approximately preserved. In practice they compress both keys and values along the time axis into a fixed dimension: $K \in \mathbb{R}^{T \times f} \to K' \in \mathbb{R}^{k \times f}$ and $V \in \mathbb{R}^{T \times f} \to V' \in \mathbb{R}^{k \times f}$. The attention matrix $Q K'^T \in \mathbb{R}^{T \times k}$ is then linear in time.

    • Complexity: O(T) memory and time (ignoring f).
  • Nyströmformer also leverages the low-rank assumption of the attention matrix. It smartly uses the Nyström method to approximate the attention matrix using a subset of the rows of $Q$ and $K$.

    • Complexity: O(T) memory and time (ignoring f).
  • Reformer: uses locality-sensitive hashing to reduce the number of dot-products. However, keys and queries need to be identical, which limits its modelling power and its usage for cross-attention tasks.

    • Complexity: $O(T \log T)$ memory and time (ignoring $f$).
  • Sparse Transformer: does a sparse factorization of the attention matrix.

    • Complexity: $O(T\sqrt{T})$ memory and time (ignoring $f$).
  • Big Bird: applies global attention to a few tokens, combined with local attention and random connections for the rest, to reduce the dimensionality of the attention matrix.

  • Transformers are RNNs: Use the reverse kernel trick to approximate the softmax attention. The following section expands on this idea.

    • Complexity: O(T) memory and time (ignoring f).

Here are some other relevant methods with interesting ideas; if I find the time one of these days, I’ll add a two-sentence explanation of each of them.

21 I recommend this post and this other post to start going into this 🐇 hole.

However, for the purposes of today’s blog, in this section I’ll focus on the Transformers are RNNs paper. Before we jump in, let’s first make sure we are on the same page about the kernel trick (Tip 6).

Kernel function

Given $x, y \in \mathbb{R}^n$, we say that $K: \mathbb{R}^n \times \mathbb{R}^n \to \mathbb{R}$ is a kernel if there exists another function $\phi: \mathbb{R}^n \to \mathbb{R}^m$ which satisfies:

$$K(x, y) = \phi(x) \cdot \phi(y)$$

We (humans) have found a few of these kernels. For instance, the Polynomial Kernel:

$$K(x, y) = (x^T y + c)^d$$

Where, if we take for instance $d = 2$, the projection function $\phi(x)$ is:

$$\phi(x) = \left[x_n^2, \dots, x_1^2, \sqrt{2}\,x_n x_{n-1}, \dots, \sqrt{2}\,x_n x_1, \sqrt{2}\,x_{n-1} x_{n-2}, \dots, \sqrt{2}\,x_2 x_1, \sqrt{2c}\,x_n, \dots, \sqrt{2c}\,x_1, c\right]^T$$

Another typical one is the Gaussian RBF Kernel:

$$K(x, y) = \exp\left(-\frac{\lVert x - y \rVert^2}{2\sigma^2}\right)$$

Here the explicit feature mapping $\phi(x)$ is infinite-dimensional, but the kernel computes the inner product in this infinite-dimensional space, which can be very powerful.

The Kernel Trick

In Machine Learning we call “the kernel trick” the usage of a kernel function $K$ to “simulate” the projection of the data into a higher-dimensional space. It is usually easier to linearly separate datapoints after nonlinearly projecting them into a higher-dimensional space. However, explicitly projecting onto higher-dimensional spaces and then calculating the dot-product (as we would need to do if we naively computed $\phi(x) \cdot \phi(y)$) is computationally expensive.

In essence, it is a way to minimize computations.
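To make the definition concrete, here is a quick numeric check for the degree-2 polynomial kernel, using one standard choice of explicit feature map (ordered slightly differently from the vector written above, but equivalent):

```python
import torch

def phi(x, c=1.0):
    """Explicit feature map for the degree-2 polynomial kernel (x.y + c)^2."""
    outer = torch.outer(x, x).reshape(-1)            # all products x_i * x_j
    return torch.cat([outer, (2 * c) ** 0.5 * x, torch.tensor([c])])

x, y, c = torch.randn(5), torch.randn(5), 1.0
lhs = (x @ y + c) ** 2                               # kernel evaluated directly: O(n)
rhs = phi(x, c) @ phi(y, c)                          # dot product in the lifted space: O(n^2)
print(torch.allclose(lhs, rhs))                      # True (up to float error)
```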

Softmax as a kernel function

Consider a vanilla linear transformer²²:

22 Linear because all dependencies are linear: we remove the softmax.

$$\mathrm{Attention}(Q, K, V) = QK^TV$$

Where $Q, K, V \in \mathbb{R}^{T \times f}$. Traditionally (as I explained before), we’d multiply the matrices in this order:

$$\mathrm{Attention}(Q, K, V) = (QK^T)V$$

However, computing $QK^T \in \mathbb{R}^{T \times T}$ and materializing the output requires $O(T^2)$ memory 😞

But now we can do better! Since we don’t need to apply the softmax, using the associative property of matrix multiplication we can compute $K^TV$ first:

$$\mathrm{Attention}(Q, K, V) = Q(K^TV)$$

Since $K^TV \in \mathbb{R}^{f \times f}$, it only requires $O(f^2)$ memory! Once we multiply by $Q$, it ends up being $O(Tf)$ memory. For long sequences this massively reduces computational and memory costs. As hinted before, though, this vanilla implementation doesn’t work as well as the softmax version of the transformer²³.

23 More rigorous studies of this here
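A quick check of the associativity argument for the non-causal, unnormalized case:

```python
import torch

T, f = 2048, 64
Q, K, V = (torch.randn(T, f, dtype=torch.float64) for _ in range(3))

left  = (Q @ K.T) @ V    # materializes a (T, T) matrix: O(T^2) memory, O(T^2 f) FLOPs
right = Q @ (K.T @ V)    # materializes only an (f, f) matrix: O(Tf) memory, O(T f^2) FLOPs

print(torch.allclose(left, right))   # True: same result, very different cost
```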

Ok, so that doesn’t work… What can we do about it though? Wouldn’t it be quite nice if softmax were a kernel function and there existed some $\phi$ (applied row-wise) such that:

$$\mathrm{softmax}(QK^T) = K(Q, K) = \phi(Q)\phi(K)^T$$

We would then be able to write:

$$\mathrm{Attention}(Q, K, V) = \phi(Q)\left(\phi(K)^TV\right)$$

And get all the gains I explained before²⁴.

24 Notice we are applying the kernel trick in a reversed way as usually 🤯

25 Intuitively: softmax (same as the RBF kernel) has an exponential, which has an infinite Taylor expansion.

So yeah… That would be quite nice, but sadly we don’t live in happyland where unicorns bring you lollipops for lunch 🦄🍭. There is no free lunch! $\phi$ would need to be infinite-dimensional²⁵, which is quite counter-productive in this case haha.

We can choose other ϕ functions though. In the paper Transformers are RNNs, they experimentally show that row-wise applying

$$\phi(x) = \mathrm{elu}(x) + 1$$

has a performance on par with standard softmax transformers while significantly reducing computational and memory requirements.

Those are all very cool ideas! We’ll now break down the complexity at train/inference times, and try to see it through the RNN lenses 🕶️.

Train-time

As I explained in the previous section, we essentially just need to compute:

$$\mathrm{Attention}(Q, K, V) = \phi(Q)\left(\phi(K)^TV\right)$$

Computationally:

| Operation | Memory | FLOPs |
|---|---|---|
| $Q, K, V$ projections | $Tf$ | $Tf^2$ |
| Apply $\phi$ | 1 | $Tf$ |
| Compute $\phi(K)^TV$ | $f^2$ | $Tf^2$ |
| Multiply by $\phi(Q)$ | $Tf$ | $Tf^2$ |
| Total | $O(Tf)$ | $O(Tf^2)$ |

That is very nice if we always have the whole sequence available, both at train and inference time. Usually, however, we’ll want to hide the future from current and past observations: Tip 8.

How does this work? We clearly can’t just multiply $\phi(Q)\phi(K)^T$ by a triangular mask, since we don’t materialize it now:

$$\mathrm{Attention}(Q, K, V) = \left(L \circ (QK^T)\right)V$$

Would force us to compute $\phi(Q)\phi(K)^T$, defeating the purpose of this approach. It is interesting to write it down though, since it’ll come up later in the blog 😉.

For now, let’s take a step back and define:

$$V' := \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{f}}\right)V$$

More generally, instead of the softmax we can have any similarity function $\mathrm{sim}$, which, applied to a particular time step (aka row), gives:

$$V'_t = \frac{\sum_{\tau \le T} \mathrm{sim}(Q_t, K_\tau)\, V_\tau}{\sum_{\tau \le T} \mathrm{sim}(Q_t, K_\tau)}$$

Notice that in the softmax transformer we have $\mathrm{sim}(q, k) = e^{\frac{q \cdot k}{\sqrt{f}}}$. An easy way to interpret this formulation is the following:

The new value (at time $t$) is a weighted average of all the other values. The weights are given by the affinity between the query (at time $t$) and each of the keys.

It is easy to see now that if we want to avoid future observations affecting our current value, we just need to limit the sum to $\tau \le t$:

$$V'_t = \frac{\sum_{\tau \le t} \mathrm{sim}(Q_t, K_\tau)\, V_\tau}{\sum_{\tau \le t} \mathrm{sim}(Q_t, K_\tau)}$$

In our currently-studied case though, since $\mathrm{sim}(q, k) = \phi(q)\,\phi(k)^T$, we have:

$$V'_t = \frac{\phi(Q_t)\, \sum_{\tau \le t} \phi(K_\tau)^T V_\tau}{\phi(Q_t)\, \sum_{\tau \le t} \phi(K_\tau)^T}$$
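At train time this causal version can still be computed in parallel with cumulative sums instead of a $T \times T$ matrix. A naive sketch (it materializes a $T \times f \times f$ tensor, so in practice one would chunk or fuse this; $\phi = \mathrm{elu} + 1$ is the choice from the paper):

```python
import torch
import torch.nn.functional as F

T, f = 1024, 64
Q, K, V = (torch.randn(T, f) for _ in range(3))
phi = lambda x: F.elu(x) + 1                                 # feature map from the paper

phiQ, phiK = phi(Q), phi(K)
# Running sums over tau <= t, computed in parallel with cumulative sums:
S = torch.cumsum(phiK[:, :, None] * V[:, None, :], dim=0)    # (T, f, f): sum of phi(k_tau)^T v_tau
Z = torch.cumsum(phiK, dim=0)                                # (T, f):    sum of phi(k_tau)^T

num = torch.einsum("tf,tfg->tg", phiQ, S)                    # phi(Q_t) S_t
den = torch.einsum("tf,tf->t", phiQ, Z)                      # phi(Q_t) Z_t
Y = num / den[:, None]                                       # (T, f), causal by construction
```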

Inference-time

Following the derivation presented in Tip 8, we have that:

$$V'_t = \frac{\phi(Q_t)\, \sum_{\tau \le t} \phi(K_\tau)^T V_\tau}{\phi(Q_t)\, \sum_{\tau \le t} \phi(K_\tau)^T}$$

Here it is useful to define these matrices:

$$S_t := \sum_{\tau \le t} \phi(K_\tau)^T V_\tau$$

$$Z_t := \sum_{\tau \le t} \phi(K_\tau)^T$$

Notice that $S_t \in \mathbb{R}^{f \times f}$ and $Z_t \in \mathbb{R}^{f \times 1}$. We then have:

$$V'_t = \frac{\phi(Q_t)\, S_t}{\phi(Q_t)\, Z_t}$$

What is cool about this is that both $S_t$ and $Z_t$ can be computed in constant time from $S_{t-1}$ and $Z_{t-1}$ respectively. We just need to add the projections of the last input:

$$S_t = S_{t-1} + \phi(K_t)^T V_t$$

$$Z_t = Z_{t-1} + \phi(K_t)^T$$

Computationally, at each time-step we have $O(f^2)$ FLOPs and memory 🎉. More visually:

1. We compute $q_t, k_t, v_t$ from $x_t$.
2. We compute the internal states $S_t, Z_t$. Check Tip 9 for interpretability.
3. $y_t$ is the query times $S_t$, normalized by the query times $Z_t$.
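A sketch of this recurrent inference step, assuming $q_t, k_t, v_t$ have already been projected from $x_t$:

```python
import torch
import torch.nn.functional as F

f = 64
phi = lambda x: F.elu(x) + 1

S = torch.zeros(f, f)                    # running sum of phi(k_tau)^T v_tau
Z = torch.zeros(f)                       # running sum of phi(k_tau)^T

def step(q_t, k_t, v_t):
    """One inference step: O(f^2) FLOPs and memory, independent of t."""
    global S, Z
    S = S + torch.outer(phi(k_t), v_t)   # S_t = S_{t-1} + phi(k_t)^T v_t
    Z = Z + phi(k_t)                     # Z_t = Z_{t-1} + phi(k_t)^T
    return (phi(q_t) @ S) / (phi(q_t) @ Z)

for _ in range(5):
    y_t = step(torch.randn(f), torch.randn(f), torch.randn(f))
```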

I think of it the following way: each row of $k_t^T v_t \in \mathbb{R}^{f \times f}$ is the value tensor weighted by the corresponding component of the key tensor:

How I imagine $k_t^T v_t$.

In other words: each component of the key vector decides how influential the value vector is to its corresponding row of the $S$ matrix. $k$ decides how to store $v$ in $S$.

How $S$ rows get updated with a new input.

Each component of the query encodes how relevant its corresponding row of $S$ is for the given input. This means $qS$ returns a weighted average of the stored values (which are already a weighted accumulation of previous-step values, as we’ve seen). This gets normalized by how relevant each component is for the given input (query), times how relevant each row has been so far ($Z$). So, if we are querying something very strange (something that has not been accumulated much), we’ll have both a low $qS$ vector and a low $qZ$ number, so the normalization makes sense.


Oversimplified: $k$ decides where to store each $v$²⁶, and $q$ decides how to retrieve it.

Interestingly, we can also see $k$ as a kind of selectivity mechanism: if the model decides a particular input $x_t$ is not very relevant, $k$ can be close to 0, so that $x_t$ doesn’t affect the hidden state $S$.

Example

I thought of an example that can help understand the role of each component of $k$: imagine, in a language-modelling task, that the first component of $k$ activates (presents a high value) whenever there is a proper noun. The first row of $S$ will then be storing proper-noun information from the input text. Whenever the model needs to retrieve a proper noun, the query will have a high first component. This will yield a result vector whose components are mostly proper-noun information from the seen text.

26 Notice we can quite analogously think of it column-wise instead of row-wise.

Linear Transformer Decoder through the RNN lenses

We can then see $S_t$, $Z_t$ as the internal state of an RNN.

In $S$ we combine the keys and values into a single matrix. Through time we keep adding stuff (which gets normalized by $Z$).

State Space Models (SSMs)

For a much more in-depth analysis of SSMs, check out my post about Mamba models. I’ll focus on Mamba1 and Mamba2, SSMs which are:

  • Structured: the matrix $A$ is forced to take a particular form (diagonal for Mamba1, scalar-times-identity for Mamba2).

  • Selective: model parameters depend on the input at each time-step.

Mamba1 (S6)

Train-time

I’ll just focus on inference-time interpretability, since train-time gets a bit deep and I already explored it in the Mamba post. The TLDR is that they develop a custom CUDA kernel, implementing what they call the scan operation, which efficiently computes the outputs provided a complete input sequence.

Inference-time

Given $x_t \in \mathbb{R}^1$ ²⁷, Mamba1’s SSM layer does the following operation:

27 Note: usually $x_t \in \mathbb{R}^d$ and a different SSM “head” is used for each dimension. Everything I explain here is just broadcast along the $d$ dimensions of the input (as if it were batched).

$$B_t = \mathrm{Linear}_B(x_t) \in \mathbb{R}^f \qquad C_t = \mathrm{Linear}_C(x_t) \in \mathbb{R}^f \qquad \Delta_t = \mathrm{Linear}_\Delta(x_t) \in \mathbb{R}^f$$

Remember they use a pre-fixed $A$ matrix, the same one as introduced in the Diagonal State Spaces (DSS) paper. Applying the discretization step, we obtain:

$$A_t \in \mathbb{R}^{f \times f} \qquad B_t \in \mathbb{R}^{f \times 1} \qquad C_t \in \mathbb{R}^{1 \times f}$$

For the recurrence to be efficiently computable, $A_t$ is restricted to be diagonal. This is called a structured matrix, hence structured SSM. Therefore, we can store and manipulate only its diagonal elements, as if it were a vector: $A_t \in \mathbb{R}^f$.

We then apply the recurrence relation:

$$h_t = A_t h_{t-1} + B_t x_t \qquad y_t = C_t^T h_t$$
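A per-channel sketch of this selective recurrence (the projections and the discretization are simplified and illustrative; see the Mamba post for the exact formulas):

```python
import torch
import torch.nn.functional as F

f = 16                                             # hidden state size per input channel
W_B, W_C, w_dt = torch.randn(f), torch.randn(f), torch.randn(1)
A = -torch.arange(1, f + 1, dtype=torch.float32)   # fixed diagonal A, stored as a vector

h = torch.zeros(f)

def ssm_step(x_t):
    """One Mamba1-style step for a single scalar channel: O(f) FLOPs and memory."""
    global h
    B_t, C_t = W_B * x_t, W_C * x_t                # input-dependent (selective) projections
    dt = F.softplus(w_dt * x_t)                    # input-dependent step size
    A_bar = torch.exp(dt * A)                      # discretized diagonal A_t (elementwise)
    B_bar = dt * B_t                               # simplified (Euler-style) discretization
    h = A_bar * h + B_bar * x_t                    # h_t = A_t h_{t-1} + B_t x_t
    return C_t @ h                                 # y_t = C_t^T h_t

for x_t in torch.randn(100):
    y_t = ssm_step(x_t)
```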

Computationally:

*Remember, though, that this is just for one component; we need to do this for each component of $x_t \in \mathbb{R}^d$. Thus everything gets multiplied by $d$ (where usually $d \gg f$).

| Operation | Memory | FLOPs |
|---|---|---|
| $B, C, \Delta$ projections | $f$ | $f$ |
| Discretize | $f$ | $f$ |
| Compute $B_t x_t$ | $f$ | $f$ |
| Compute $A_t h_{t-1}$ | $f$ | $f$ |
| Compute $C_t^T h_t$ | $f$ | $f$ |
| Total | $O(f)$* | $O(f)$* |

Mamba1 inference visualized.
Mamba1 through the RNN lenses

It is very straightforward in this case 😂. Thinking in terms of gating mechanisms:

  • $h_t$ is the internal RNN state.
  • $A_t$ controls which components of $h_{t-1}$ get forgotten.
  • $B_t$ controls how $x_t$ gets added to $h_t$.
  • $C_t$ controls which components compose the output.

Mamba2 (SSD)

Mamba2’s SSM layer, presented within the SSD (State Space Duality) framework, introduces two key changes w.r.t. Mamba1:

  1. It further restricts the $A$ matrix to be of the scalar-times-identity type²⁸:

28 Instead of diagonal.

$$A_t = a_t I$$

  2. It directly works with multi-dimensional input-output pairs²⁹:

29 Instead of scalars.

$$x, y \in \mathbb{R}^{T \times d}$$

Train-time

Since at train time we have the full sequence $x \in \mathbb{R}^T$ available, applying the same operations as before (linear projection + discretization) we can pre-compute:

$$a \in \mathbb{R}^T \qquad B \in \mathbb{R}^{T \times f} \qquad C \in \mathbb{R}^{T \times f}$$

Now the problem gets simplified to the point that we can express the input-output mapping directly as a single matrix multiplication! To do so, let’s define a new matrix $L$:

$$L = \begin{bmatrix}
1 & 0 & 0 & \cdots & 0 \\
a_1 & 1 & 0 & \cdots & 0 \\
a_2 a_1 & a_2 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\prod_{t=1}^{T-1} a_t & \prod_{t=2}^{T-1} a_t & \prod_{t=3}^{T-1} a_t & \cdots & 1
\end{bmatrix}$$

We then have that:

$$y = \left(L \circ (CB^T)\right) x$$

It is actually very simple to see this, given the recurrence defined by the discretized SSM problem and the constraints:

$$h_t = a_t h_{t-1} + B_t x_t \qquad y_t = C_t^T h_t$$

Then:

$$h_{-1} = 0$$
$$y_0 = c_0^T (a_0 \cdot 0 + b_0 x_0) = c_0^T b_0\, x_0$$
$$y_1 = c_1^T (a_1 h_0 + b_1 x_1) = c_1^T (a_1 b_0 x_0 + b_1 x_1)$$
$$y_2 = \dots$$
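A small numeric check that the recurrent form and the matrix form compute the same map (single scalar channel, random parameters):

```python
import torch

T, f = 6, 4
torch.manual_seed(0)
a = torch.rand(T)                          # scalar a_t per step (A_t = a_t * I)
B = torch.randn(T, f)                      # B_t as rows
C = torch.randn(T, f)                      # C_t as rows
x = torch.randn(T)                         # one scalar input channel

# Recurrent form: h_t = a_t h_{t-1} + B_t x_t,  y_t = C_t . h_t
h = torch.zeros(f)
y_rec = []
for t in range(T):
    h = a[t] * h + B[t] * x[t]
    y_rec.append(C[t] @ h)
y_rec = torch.stack(y_rec)

# Matrix form: y = (L o C B^T) x, with L[t, tau] = a_t * ... * a_{tau+1} for tau <= t
L = torch.zeros(T, T)
for t in range(T):
    for tau in range(t + 1):
        L[t, tau] = torch.prod(a[tau + 1 : t + 1])
y_mat = (L * (C @ B.T)) @ x

print(torch.allclose(y_rec, y_mat, atol=1e-5))   # True: same map, two computation orders
```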

On its relationship with linear transformers

Hold your 🐎🐎! Didn’t we see something very similar already??

Yes! The linear transformer has almost the same form!! 🤯

$$y = \left(L \circ (QK^T)\right) V$$

The linear transformer is a particular case of the SSD framework, where $a_t = 1 \ \forall t$ instead of depending on the input $x$.

This can also be understood as a generalization of positional encoding, as it is now input-dependent instead of fixed (sinusoidal, rotary (RoPE), or whatever). Intuitively: the further apart two points $x_{t_1}, x_{t_2}$ are, the smaller the product $\prod_t a_t$ between them will be, and the weaker their dependence, depending of course on what lies between them.

On memory requirements of Mamba2

Keep your shirt on 👕! Doesn’t this formulation force us to use the quadratic formulation of linear transformers?

Good question, and yes, if we implement it naively! However, the matrix $M := L \circ (CB^T)$ is highly structured. In particular, it can be efficiently split into sub-blocks, parallelizing matrix multiplications while avoiding recomputation.

On its modelling power

Cool your jets 🛩️! Isn’t it counter-productive to restrict $A$ to be scalar-times-identity?

This is something still being tested³⁰. By forcing $A$ to be scalar-times-identity we lose the ability to select which components of the hidden state get erased: given an input $x_t$, we proportionally keep all $h_{t-1}$ components (we either erase, fade or enhance them all equally). However, it looks like you can gain more by simply allowing a higher $x$ dimensionality, which is analogous to having more attention heads. Plus, by leveraging better (more hardware-optimized) matrix multiplications, we get much faster training than Mamba1.

30 Time of writing this is December 2024.

In theory:

  • Mamba2 training >> Mamba1 training
  • Mamba1 inference performance > Mamba2 inference performance

In practice:

  • Early results seem to indicate that the trade-offs taken in Mamba2 are worth it: Mamba2 performs on par or better on early benchmarks³¹.

31 More testing needed.

Inference-time

I’ll not dive into the inference logic of SSD, since it is analogous to the already-covered linear transformer.

Epilogue

Alright, that got a bit out of hand (unsurprisingly), but it was useful for me to connect some ideas I had about sequence modelling. See you around!

Pasting it here again to wrap things up:

Let’s tackle the problem $X \to Y$, where $X, Y$ are sequences.

  • What about $y_t = f_\omega(x_t) \ \forall t$? (Feed Forward)
    • Too simplistic: we are not using the fact that the data is ordered in a meaningful way.
  • What about $y_t = f_\omega(x_t, x_{t-1}, \dots, x_{t-k}) \ \forall t$? (1D CNN)
    • We’d need to concatenate many of those layers for context to flow from beginning to end.
  • What about $y_t, S_t = f_\omega(x_t, S_{t-1}) \ \forall t$? (RNN)
    • In general this is not parallelizable (super slow training), the network forgets information, and we have vanishing gradient problems.
  • What about making $S_t$ the concatenation of all previously seen $x_t$, i.e. $S_t = \{x_t, x_{t-1}, \dots, x_0\}$? (Transformer)
    • This works great but consumes a lot of memory and FLOPs.
  • What about a more efficient approximation of the previous one?
    • Nice ideas, but not as good as the softmax transformer.
  • Let’s go back to the RNN idea: isn’t there a way to train it more efficiently? (Mamba1)
    • Yes, if we add some structure.
  • Can’t we run it faster? (Mamba2)
    • Yes, if we add even more structure. Wait, we re-discovered a more general version of the linear transformer coming from the SSM branch!

The end.