Apparently, mosquitoes 🦟 —when flying at night— rely on a source of light for navigation. Natural sources of light tend to be very far away 🌘. Thus, mosquitoes can follow a straight line by simply keeping a constant angle with them. However, artificial lights 💡 break this system: if one keeps a steady angle with a close object, one ends up going in circles…
That being said, today I want to explore various concepts around “sequence modeling”. Let’s start by formalizing it a bit. We’ll consider two sequences:
- Input sequence: $x = (x_1, x_2, \dots, x_T)$
- Output sequence: $y = (y_1, y_2, \dots, y_T)$

Where $x_t, y_t \in \mathbb{R}^D$ for every time-step $t$.
I’ll try to provide an intuitive understanding of how common models work, how they relate to one another1, and their computational complexity.
1 Mainly through the lens of RNNs, because of their easy interpretability.
Focus
- I’ll be mainly focusing on the computational bottleneck part of each model: the component performing attention, recursion, selectivity…
- I’ll mostly ignore feed-forward, normalization and other embarrassingly-parallelizable steps.
Structure
Whenever it is interesting, I’ll split model formulation between:
Train-time: I assume the whole input sequence is available. Here we are usually interested in processing the whole sequence in a single step (i.e. not recurrently), to take advantage of having the complete sequence at hand.
Inference-time: I’ll focus on the case where the complete sequence is not available. For instance: when processing a stream of data (e.g. real-time audio transcription), or when running the model in an autoregressive manner (e.g. next-token prediction).
Notation
- $T$: Length of the sequence at train time.
- $t$: Length of the sequence (so far) at inference time.
- $D$: Feature dimension, usually both the input-output dimension and the hidden state size (if applicable, otherwise specified). It is assumed $D \ll T$.
Inspiration
- I’ll talk about a lot of papers (especially when covering the attempts to optimize transformers), but the three main sources of inspiration are:
Narrative
There are so many details to cover that the big picture of the post might get a bit lost. Here is what I was going for:
Let’s tackle the problem
- What about $y_t = f(x_t)$? (Feed Forward)
  - Too simplistic: we are not using the fact that the data is ordered in a meaningful way.
- What about $y_t = f(x_{t-k}, \dots, x_t)$? (1D CNN)
  - We’d need to stack many of those layers for context to flow from the beginning to the end of the sequence.
- What about $y_t, h_t = f(x_t, h_{t-1})$? (RNN)
  - In general this is not parallelizable (very slow training), the network forgets information, and we have the vanishing-gradient problem.
- What about making $h_t$ the concatenation of all previously seen $x_i$, i.e. $h_t = (x_1, \dots, x_t)$? (Transformer)
  - This works great but consumes a lot of memory and FLOPs.
- What about a more efficient approximation of the previous one?
  - Nice ideas, but not as good as the softmax transformer.
- Let’s go back to the RNN idea: isn’t there a way to train it more efficiently? (Mamba1)
  - Yes, if we add some structure.
- Can’t we run it faster? (Mamba2)
  - Yes, if we add even more structure. Wait, we re-discovered a more general version of the linear transformer coming from the SSM branch!
In the following table you can see the memory and FLOPs costs of the main models we’ll cover:
| Algorithm | Train Memory | Train FLOPs | Inference Memory | Inference FLOPs | One-step computable | Global context |
|---|---|---|---|---|---|---|
| FF | $O(TD)$ | $O(TD^2)$ | $O(D)$ | $O(D^2)$ | 🟢 | 🔴 |
| 1D CNN (kernel size $k$) | $O(TD)$ | $O(TkD^2)$ | $O(kD)$ | $O(kD^2)$ | 🟢 | 🔴 |
| Standard RNN | $O(TD)$ | $O(TD^2)$ | $O(D)$ | $O(D^2)$ | 🔴 | 🟡3 |
| Naive Softmax Transformer | $O(T^2 + TD)$ | $O(T^2D + TD^2)$ | $O(tD)$ | $O(tD + D^2)$ | 🟢 | 🟢 |
| Flash-Attn Softmax Transformer | $O(TD)$ | $O(T^2D + TD^2)$2 | $O(tD)$ | $O(tD + D^2)$ | 🟢 | 🟢 |
| Mamba1 (SSM state size $N$) | $O(TDN)$ | $O(TDN + TD^2)$ | $O(DN)$ | $O(DN + D^2)$ | 🟡4 | 🟡3 |
| Linear Transformer & SSD (Mamba2) | $O(TD + D^2)$ | $O(TD^2)$ | $O(D^2)$ | $O(D^2)$ | 🟢 | 🟡3 |
2 But in practice much faster, thanks to the fused kernel.
3 Because information gets compressed.
4 Heavily optimized with strong structure and the scan operation, but still not one-step.
But why is sequence modelling a challenging problem?
High dimensionality: Sequences can be very long5, which poses memory issues.
Iterative nature: Since the ordering of the datapoints is relevant, a lot of algorithms cannot be parallelized without heavy memory and computational drawbacks, or strongly enforced structure. It doesn’t matter how nice your GPU is if you need to process your sequence one element at a time.
5 Enterprise code repositories can easily be on the order of 100k lines of code; a single second of audio has 44k datapoints (if recorded at a standard 44kHz); the human genome has 3.1 billion base pairs.
Bad combo… Still, there are several smart methods to overcome these limitations. Put on your dancing shoes 🩰 because the show is about to start!
Feed-forward (FF)
Alright let’s get this over with! What’s the easiest thing we can do? 🤔.
Given an input element $x_t$, the simplest thing is to predict the output using that element alone6:

$$y_t = f(x_t)$$
6 I include this basic model to establish a computational lower bound and provide context for more sophisticated approaches.
At the simplest, $f$ can be a single affine map7 8:

$$y_t = W x_t + b$$
7 We can also add some non-linearity and compose multiple functions to increase modelling power.
8 In language modelling this is known as a 2-gram model, in which case $x_t$ is the current token and $y_t$ is the predicted distribution over the next one.
This has a great computational appeal: it is extremely parallelizable at train-time, and inference can be done in constant time and memory. However, its limited modeling capabilities make it insufficient for most real-world applications: mainly, it doesn’t leverage the sequential nature of the data! There is no information flow between elements of the sequence.
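As a minimal sketch (NumPy, with illustrative shapes and names of my own), the position-wise map just applies the same weights to every element, with no interaction between time-steps:

```python
import numpy as np

T, D = 6, 4                      # sequence length, feature dimension
x = np.random.randn(T, D)        # input sequence (T, D)

W = np.random.randn(D, D)        # shared weights, identical for every position
b = np.zeros(D)

y = x @ W + b                    # all T positions processed at once: no
                                 # information flows between time-steps
assert y.shape == (T, D)
```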
Still, layers like this play an important role in more complex models: they allow us to non-linearly combine internal features, normalize features, and store model knowledge9.
9 I will not spend more time on this, as it is the part of the models I said I would not focus on (Tip 1).
Convolutional Neural Network (1D-CNN)
Ok, so how can we do better? 🤷
Instead of mapping each input of the sequence element-wise to an output, we can map a fixed window of $k$ consecutive inputs to each output:

$$y_t = f(x_{t-k+1}, \dots, x_t)$$
This keeps most of the computational appeal of the previous idea while also allowing us to locally transfer information along the time dimension.
Interestingly, we can stack multiple layers of this type to propagate information across longer time-spans:
With each successive convolution, there is a cumulative aggregation of local features, which captures information from larger, more general aspects of the input, finally yielding a global understanding of it. In particular, if we have an input of length $T$, we’d need on the order of $T/k$ such layers for information from the first input to reach the last output.
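Here is a rough sketch of a causal 1D convolution in NumPy (weights and shapes are placeholders of my own):

```python
import numpy as np

def causal_conv1d(x, W):
    """Causal 1D convolution: y_t only depends on x_{t-k+1}, ..., x_t.
    x: (T, D) input sequence, W: (k, D, D) kernel weights."""
    k, D, _ = W.shape
    T = x.shape[0]
    x_pad = np.concatenate([np.zeros((k - 1, D)), x])   # left-pad so outputs stay causal
    y = np.zeros((T, D))
    for t in range(T):
        window = x_pad[t:t + k]                          # the k inputs feeding y_t
        y[t] = np.einsum('kd,kde->e', window, W)
    return y

T, D, k = 16, 4, 3
x = np.random.randn(T, D)
y = causal_conv1d(x, np.random.randn(k, D, D) * 0.1)
# After stacking L such layers, y_t sees roughly the last L*(k-1)+1 inputs,
# so on the order of T/k layers are needed for global context.
```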
This is very intuitive in 2D CNNs used for vision: the first layers extract the position of very simple features, such as edges. As the input advances through the network, these features get combined into more and more complex patterns. In the last layers, features become recognizable common shapes such as eyes, wheels, roofs… This idea was key in early/small computer vision tasks (e.g. AlexNet) and helped propel the ML field. Now, transformers (e.g. the Vision Transformer) perform this much more effectively by directly allowing all-to-all input interactions at each layer. This means all inputs have influence on all outputs at each layer (let’s forget about causality for now), removing the need for very deep networks. I provide more intuitive understanding of this in my post about deep dream.
This approach presents other disadvantages, like time/position invariance: the kernel parameters are fixed and shared across all positions, regardless of the content being processed.
Overall, not being designed for long-context information transmission makes them a bad candidate for the studied problem. As with FFNs though, some modern sequence models include CNNs to locally combine features (see the Mamba models) or to compress the sequence’s temporal dimensionality (e.g. the first layers of the Whisper speech-to-text model).
Recurrent Neural Network (RNN)
Hmmm, so how can we more effectively transfer information through time? 💭
We can store some internal state $h_t$ containing the relevant information that needs to be transmitted across time10:

$$h_t = f(h_{t-1}, x_t), \qquad y_t = g(h_t)$$
10
This is a very generic representation and there exist multiple ways of implementing it (as we’ll later see). In Tip 3 I summarize a couple of influential, now classic, RNNs.
Here I present two of the most relevant functional forms for $f$:
GRU

$$z_t = \sigma(W_z x_t + U_z h_{t-1})$$
$$r_t = \sigma(W_r x_t + U_r h_{t-1})$$
$$\tilde{h}_t = \tanh\left(W_h x_t + U_h (r_t \odot h_{t-1})\right)$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$
LSTM

$$f_t = \sigma(W_f x_t + U_f h_{t-1}), \quad i_t = \sigma(W_i x_t + U_i h_{t-1}), \quad o_t = \sigma(W_o x_t + U_o h_{t-1})$$
$$\tilde{c}_t = \tanh(W_c x_t + U_c h_{t-1})$$
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \qquad h_t = o_t \odot \tanh(c_t)$$

(biases omitted for brevity)
Notice that LSTMs maintain two separate states: an internal cell state $c_t$ and a hidden state $h_t$, to better isolate and preserve long-range information.
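As a rough sketch, one GRU step in NumPy following the equations above (random placeholder weights, biases omitted):

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, params):
    """One GRU update: returns the new hidden state h_t."""
    Wz, Uz, Wr, Ur, Wh, Uh = params
    z = sigmoid(x_t @ Wz + h_prev @ Uz)              # update gate
    r = sigmoid(x_t @ Wr + h_prev @ Ur)              # reset gate
    h_tilde = np.tanh(x_t @ Wh + (r * h_prev) @ Uh)  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde          # blend old and new information

D = 4
params = [np.random.randn(D, D) * 0.1 for _ in range(6)]
h = np.zeros(D)
for x_t in np.random.randn(10, D):                   # inherently sequential:
    h = gru_step(x_t, h, params)                     # step t needs h_{t-1}
```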
In their classic form (constant, input-independent parameters), these RNNs present three main problems:

- Non-parallelizable training, which becomes prohibitive for long sequences. The parallelization power of GPUs is lost if the computation of $h_t$ is blocked by the computation of $h_{t-1}$. One can still parallelize along the batch dimension, but weight updates are still too slow in comparison to the other methods.
- The fixed state size might be too small to compress all the relevant information of the sequence, resulting in forgetting problems.
- Vanishing gradients, which arise from back-propagation through time. In the backward pass, for each step $t$ we compute the gradient as $\frac{\partial \mathcal{L}}{\partial h_t} = \frac{\partial \mathcal{L}}{\partial h_T} \prod_{i=t+1}^{T} J_i$, where $J_i = \frac{\partial h_i}{\partial h_{i-1}}$ is the Jacobian matrix of step $i$. Consequently, since $J_i$ is usually contractive11, its cumulative product decays or “vanishes” with sequence length, to the point of having near-zero effect after a few iterations. This results in the network struggling to learn dependencies from earlier inputs12.
11 I.e. its eigenvalues have magnitude smaller than 1.
12 Solutions include: LSTMs (incorporate gating mechanisms to allow for longer-range dependencies), ReLU activations (instead of sigmoid or tanh, which have very small derivatives when values are far from 0), gradient clipping (to prevent gradients from becoming too large), or layer normalization (helps stabilize gradients).
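A tiny numerical illustration of this effect, using random contractive Jacobians (purely for intuition, not an actual trained network):

```python
import numpy as np

rng = np.random.default_rng(0)
D, T = 8, 50

# Random Jacobians rescaled to be contractive (largest singular value = 0.8).
jacobians = []
for _ in range(T):
    J = rng.standard_normal((D, D))
    J /= np.linalg.norm(J, 2) * 1.25
    jacobians.append(J)

grad = rng.standard_normal(D)                   # gradient at the last step
for i, J in enumerate(reversed(jacobians), 1):  # back-propagate through time
    grad = J.T @ grad
    if i in (1, 10, 25, 50):
        print(f"after {i:2d} steps back: |grad| = {np.linalg.norm(grad):.2e}")
# The norm shrinks roughly like 0.8**i: early inputs barely receive any signal.
```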
We’ll later see that we can re-work either the functional form of $f$ or the structure we impose on it to address these problems.
Transformer Decoder
Are we still doing the rhetorical question thing? Yes. Ok, so if compressing information doesn’t work, what could we do instead?
We can simply consider the complete sequence for each guess! This solves all the performance issues, at higher memory and computation costs (obviously). This is how Transformers do it:
Train-time
Given the input keys $K$, queries $Q$ and values $V$ (all in $\mathbb{R}^{T \times D}$), obtained by linearly projecting the input sequence13 14:
13 I’ll focus on self-attention, and, for simplicity, I assume all projections are done into a space of the same dimension $D$.
14 I provide interpretability of those in my Attention Mechanism Zoo post
And we then apply dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{D}} + M\right) V$$

Where the mask $M$ is zero on and below the diagonal and $-\infty$ above it (so no position attends to the future), and the softmax is applied row-wise.
| Operation | Memory | FLOPs |
|---|---|---|
| $Q, K, V$ projections | $O(TD)$ | $O(TD^2)$ |
| Computing and allocating $QK^\top$ | $O(T^2)$ | $O(T^2D)$ |
| Masking & Softmax | $O(1)$15 | $O(T^2)$ |
| Values | $O(TD)$ | $O(T^2D)$ |
| Total | $O(T^2 + TD)$ | $O(T^2D + TD^2)$ |
15 As in no extra space is needed
16 My research crush.
Luckily, Tri Dao16 & company introduced a way around this quadratic memory cost in 2022: Tip 4.
FlashAttention (May 2022) introduces two key ideas:
- Less memory usage, because the $QK^\top$ matrix is never materialized in memory.
- Faster execution, because of a fused kernel.
The idea is to substitute the PyTorch operations (or those of whichever deep learning framework is being used) with a custom CUDA kernel17 (aka fused kernel) which combines them and performs them more efficiently:
Consider the chain of operations $S = \mathrm{mask}(QK^\top)$, $P = \mathrm{softmax}(S)$, $O = PV$: instead of writing each intermediate result to main GPU memory, the fused kernel computes the output block by block, keeping the intermediates in fast on-chip memory and carrying running softmax statistics (an online softmax).
This (obviously) results in a faster and more memory-efficient attention method. The approach not only fuses the operations but also makes better use of the GPU memory hierarchy. For instance, an A100 GPU has:
40 GB (or 80 GB) of HBM (High-Bandwidth Memory): Large but slow. It is implemented by stacking multiple DRAM (Dynamic Random Access Memory) dies, allowing highly parallel data transfers.
~0.2 MB of SRAM (Static Random Access Memory) on each of its 108 processors: Small but fast. The speed advantage of SRAM over DRAM comes from SRAM’s ability to hold a given data bit in a static state (on or off) as long as power is supplied. Moreover, it is more reliable than DRAM, which must refresh its stored bits many times per second to maintain data integrity (making operations slower). However, SRAM has a much higher cost-per-bit and requires more physical space on the chip, so it is usually reserved for uses where speed and reliability are critical (such as caches in CPUs or GPUs).
17 Fancy way of saying: a function that runs on the GPU, written in CUDA.
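To make the idea concrete, here is a rough single-query sketch of the online-softmax accumulation that FlashAttention builds on (plain NumPy, no actual kernel fusion; block size and names are mine):

```python
import numpy as np

def online_softmax_attention(q, K, V, block=32):
    """Attention output for one query, processing keys/values in blocks.
    Never materializes the full score row: keeps a running max, a running
    normalizer and a running weighted sum, rescaling them per block."""
    D = q.shape[0]
    m = -np.inf            # running max of the scores seen so far
    l = 0.0                # running softmax normalizer
    acc = np.zeros(D)      # running (unnormalized) weighted sum of values
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Kb @ q / np.sqrt(D)            # scores of this block
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)          # rescale previous accumulators
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vb
        m = m_new
    return acc / l

T, D = 256, 16
q = np.random.randn(D)
K, V = np.random.randn(T, D), np.random.randn(T, D)

s = K @ q / np.sqrt(D)                     # reference: plain softmax attention
w = np.exp(s - s.max()); w /= w.sum()
assert np.allclose(online_softmax_attention(q, K, V), w @ V)
```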
Inference-time
Imagine we have cached18 the keys and values of all previously processed tokens, $K_{t-1}, V_{t-1} \in \mathbb{R}^{(t-1) \times D}$. A new input $x_t$ arrives, and we compute its projections $q_t, k_t, v_t \in \mathbb{R}^{D}$.
18 I’ll focus on the case where we use KV-caching.
Then, using the cache, we append the new key and value:

$$K_t = \begin{bmatrix} K_{t-1} \\ k_t^\top \end{bmatrix}, \qquad V_t = \begin{bmatrix} V_{t-1} \\ v_t^\top \end{bmatrix}$$
We can then compute the new attention output vector:

1. We compute the attention scores of the new query against all cached keys: $s_t = \frac{K_t\, q_t}{\sqrt{D}} \in \mathbb{R}^{t}$.
2. How is each past token related to the current one? We turn the scores into an attention distribution with a softmax: $a_t = \mathrm{softmax}(s_t)$.
3. The final result is the weighted average of the values according to the attention vector obtained in step 219 (we use the cached $V_t$): $y_t = V_t^\top a_t$.
19 If there is a strong relationship between a query and a key, that key’s associated value will have a strong influence on the output.
Computationally the cost comes down to these steps:
| Operation | Memory | FLOPs |
|---|---|---|
| $q, k, v$ projections | $O(D)$ | $O(D^2)$ |
| Computing and allocating $K_t q_t$ | $O(t)$ | $O(tD)$ |
| Softmax | $O(1)$ | $O(t)$ |
| Values | $O(D)$ | $O(tD)$ |
| Total (including the $O(tD)$ KV-cache) | $O(tD)$ | $O(tD + D^2)$ |
Notice we can see the KV-cache as the internal state of an RNN, with the particularity that it grows with the sequence length instead of having a fixed size.
Always having access to all past tokens is a key characteristic of transformers (both good and bad):
- They don’t forget (in contrast to fixed-size state of common RNNs).
- Each layer has global context (in contrast to CNNs, whose layers have local context and need to be very deep in order to extract global input features)
- They have a growing memory, making them ill-suited for very long sequences.
Linear Transformers
Uff, can’t we do some kind of approximation which is almost-as-good but at a much-lower computational cost? 😖
We can try! The choice of the softmax as the similarity function is precisely what we’ll have to relax, as we’ll see in a minute.
There have been many attempts to make the standard transformer architecture more computationally efficient (see Tip 520).
20 This could be a post on itself, here I just go over some interesting ideas.
Lowering the compute complexity of the original transformer can yield many benefits: faster processing and the ability to handle larger context windows, for instance. Despite the merit of the approaches we’ll review, the benefits obtained usually get undermined by the quality deterioration caused by the approximations made.
Still, it is worth understanding these efforts: some of them are quite neat and might inspire future methods.
We can roughly categorize these attempts into the following groups21:
- Linformer leverages the empirical idea that the attention matrix is more-or-less low-rank (they look at its eigenvalue distribution). So they use the Johnson–Lindenstrauss lemma to approximate the attention matrix. Fanciness aside, this lemma states that if we use a random projection matrix to project a set of points onto a lower dimension, the pairwise distances are approximately preserved. In practice they compress both `keys` and `values` along the time axis into a fixed dimension $k$: $K' \in \mathbb{R}^{k \times D}$ and $V' \in \mathbb{R}^{k \times D}$. Then, the attention matrix $QK'^\top$ has a shape of $T \times k$ (linear in time).
  - Complexity: $O(T)$ memory and time (ignoring $D$).
- Nyströmformer also leverages the low-rank assumption of the attention matrix. It smartly uses the Nyström method to approximate the attention matrix using a subset of $Q, K$ rows.
  - Complexity: $O(T)$ memory and time (ignoring $D$).
- Reformer uses locality-sensitive hashing to reduce the number of dot-products. However, `keys` and `queries` need to be identical, which limits its modelling power and its usage for cross-attention tasks.
  - Complexity: $O(T \log T)$ memory and time (ignoring $D$).
- Sparse Transformer: does a sparse factorization of the attention matrix.
  - Complexity: $O(T\sqrt{T})$ memory and time (ignoring $D$).
- Big Bird: applies global attention to a few tokens, combined with local attention and random connections for the rest, to reduce the dimensionality of the attention matrix.
- Transformers are RNNs: uses the reverse kernel trick to approximate the softmax attention. The following section expands on this idea.
  - Complexity: $O(T)$ memory and time (ignoring $D$).
Here are some other relevant methods with interesting ideas; if I find the time one of these days, I’ll add a two-sentence explanation for each of them.
21 I recommend this post and this other post to start going into this 🐇 hole.
However, for the purposes of today’s blog, in this section I’ll focus on the Transformers are RNNs paper. Before we jump in, let’s first make sure we are on the same page about the kernel trick: Tip 6.
Kernel function
Given two vectors $x, z \in \mathbb{R}^D$, a kernel function $k$ computes the inner product of their images under some feature map $\phi$ without ever computing the map explicitly:

$$k(x, z) = \langle \phi(x), \phi(z) \rangle$$
We (humans) have found a few of these kernels. For instance, the Polynomial Kernel:

$$k(x, z) = \left(x^\top z + c\right)^d$$
Where, for instance, if we take $c = 0$ and $d = 2$, the corresponding feature map $\phi$ contains all degree-2 products of the input coordinates.
Another typical one is the Gaussian RBF Kernel:

$$k(x, z) = \exp\!\left(-\frac{\lVert x - z \rVert^2}{2\sigma^2}\right)$$
Here we have that the explicit feature mapping $\phi$ is infinite-dimensional, so we could not compute it explicitly even if we wanted to.
The Kernel Trick
In Machine Learning we call “the kernel trick” the usage of a kernel function $k$ to operate in a very high-dimensional (possibly infinite-dimensional) feature space without ever computing coordinates in that space: we only ever need the pairwise inner products, which the kernel gives us cheaply.
In essence, it is a way to minimize computations.
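A quick NumPy check of the degree-2 polynomial kernel case above (with $c = 0$), comparing the explicit feature map against the kernel shortcut:

```python
import numpy as np

def phi(x):
    """Explicit degree-2 feature map: all pairwise products x_i * x_j."""
    return np.outer(x, x).ravel()            # dimension D^2

rng = np.random.default_rng(0)
x, z = rng.standard_normal(5), rng.standard_normal(5)

explicit = phi(x) @ phi(z)                   # inner product in the D^2-dim space
kernel = (x @ z) ** 2                        # the kernel computes it in O(D)
assert np.isclose(explicit, kernel)
```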
Softmax as a kernel function
Consider a vanilla linear transformer22:

$$Y = \left(Q K^\top\right) V$$
22 Linear because all dependencies are linear, we remove the
Where $Q, K, V \in \mathbb{R}^{T \times D}$ are the usual projections (I omit the row-wise normalization to keep things light).

However, computing $QK^\top$ first still costs $O(T^2)$ memory and $O(T^2D)$ FLOPs.

But now we can do better! Since we don’t need to run the softmax on $QK^\top$, we can use the associativity of matrix products and compute $Q\left(K^\top V\right)$ instead.

Since $K^\top V$ is only a $D \times D$ matrix, the cost drops to $O(TD^2)$ FLOPs and $O(TD + D^2)$ memory: linear in $T$! The catch is that, without the softmax non-linearity, modelling quality degrades noticeably23.
23 More rigorous studies of this here
Ok, so that doesn’t work… What can we do about it though? Wouldn’t it be quite nice if there were a feature map $\phi$ such that $\phi(q)^\top \phi(k) = \exp\!\left(\frac{q^\top k}{\sqrt{D}}\right)$?
We would then be able to write:

$$y_t = \frac{\sum_{i} \phi(q_t)^\top \phi(k_i)\, v_i}{\sum_{i} \phi(q_t)^\top \phi(k_i)} = \frac{\phi(q_t)^\top \sum_{i} \phi(k_i)\, v_i^\top}{\phi(q_t)^\top \sum_{i} \phi(k_i)}$$
And get all the gains I explained before24.
24 Notice we are applying the kernel trick in a reversed way as usually 🤯
25 Intuitively: the exponential corresponds to an infinite-dimensional feature expansion (as in the RBF kernel case), so no finite $\phi$ reproduces it exactly.
So yeah… That would be quite nice, but sadly we don’t live in happyland where unicorns bring you lollipops for lunch 🦄🍭: there is no finite-dimensional feature map that exactly reproduces the softmax similarity25. There is no free lunch!
We can choose other feature maps, though. The Transformers are RNNs paper proposes the simple choice $\phi(x) = \mathrm{elu}(x) + 1$, which has a performance on par with standard softmax transformers while significantly reducing computational and memory requirements.
Those are all very cool ideas! We’ll now break down the complexity at train/inference times, and try to see it through the RNN lenses 🕶️.
Train-time
As I explained in the previous section, we essentially just need to compute:

$$Y = \frac{\phi(Q)\left(\phi(K)^\top V\right)}{\phi(Q)\left(\phi(K)^\top \mathbf{1}\right)}$$

(with the division applied row-wise).
Computationally:
| Operation | Memory | FLOPs |
|---|---|---|
| $Q, K, V$ projections | $O(TD)$ | $O(TD^2)$ |
| Apply $\phi$ | $O(1)$ | $O(TD)$ |
| Compute $\phi(K)^\top V$ | $O(D^2)$ | $O(TD^2)$ |
| Multiply by $\phi(Q)$ | $O(TD)$ | $O(TD^2)$ |
| Total | $O(TD + D^2)$ | $O(TD^2)$ |
That is very nice if we always have the whole sequence available, both at train and inference time. Usually, however, we’ll want to hide the future from current and past observations: Tip 8.
How does this work? We clearly can’t just multiply by a triangular mask as in the softmax case: masking $\phi(Q)\phi(K)^\top$ would force us to compute and materialize the full $T \times T$ matrix again, losing all the gains.
For now, let’s take a step back and define:

$$y_t = \frac{\sum_{i=1}^{T} \mathrm{sim}(q_t, k_i)\, v_i}{\sum_{i=1}^{T} \mathrm{sim}(q_t, k_i)}$$
More generally, instead of the softmax similarity we can plug in any non-negative similarity function $\mathrm{sim}(\cdot, \cdot)$.
Notice in the softmax transformer, we have that $\mathrm{sim}(q, k) = \exp\!\left(\frac{q^\top k}{\sqrt{D}}\right)$.
The new value (at time t) is a weighted average of all other values. These weights are given by the affinity between the query (at time t) and each of the keys.
It is easy to follow now that if we want to avoid future observations affecting our current value, we just need to limit the sum up to the current time-step $t$:

$$y_t = \frac{\sum_{i=1}^{t} \mathrm{sim}(q_t, k_i)\, v_i}{\sum_{i=1}^{t} \mathrm{sim}(q_t, k_i)}$$
In our currently-studied case though, since $\mathrm{sim}(q, k) = \phi(q)^\top \phi(k)$ factorizes, we can pull $\phi(q_t)$ out of the sums and only keep cumulative sums over the keys and values:

$$y_t = \frac{\phi(q_t)^\top \sum_{i=1}^{t} \phi(k_i)\, v_i^\top}{\phi(q_t)^\top \sum_{i=1}^{t} \phi(k_i)}$$
Inference-time
Following the derivation presented in Tip 8, we have that:

$$y_t = \frac{\phi(q_t)^\top \sum_{i=1}^{t} \phi(k_i)\, v_i^\top}{\phi(q_t)^\top \sum_{i=1}^{t} \phi(k_i)}$$
Here it is useful to define these quantities:

$$S_t = \sum_{i=1}^{t} \phi(k_i)\, v_i^\top \in \mathbb{R}^{D \times D}, \qquad z_t = \sum_{i=1}^{t} \phi(k_i) \in \mathbb{R}^{D}$$
Notice that, with them, $y_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}$.
What is cool about this is that both $S_t$ and $z_t$ can be updated recurrently: $S_t = S_{t-1} + \phi(k_t)\, v_t^\top$ and $z_t = z_{t-1} + \phi(k_t)$. This is exactly an RNN with a fixed-size state!
Computationally, at each time-step we have (sketched right below this list):

1. We compute the projections $q_t, k_t, v_t$ and apply $\phi$.
2. We compute the internal states $S_t = S_{t-1} + \phi(k_t)\, v_t^\top$ and $z_t = z_{t-1} + \phi(k_t)$.
3. We compute the output $y_t = \frac{\phi(q_t)^\top S_t}{\phi(q_t)^\top z_t}$.
I think of it the following way: each row of $S_t$ stores a running sum of all the value vectors seen so far, weighted by the corresponding component of the (feature-mapped) key.

In other words: each component of the key vector decides how influential the value vector is for its corresponding row of the $S$ matrix.

Each component of the query then encodes how relevant its corresponding row of $S$ is for the current output.
Oversimplified: keys decide what gets written into the state, queries decide what gets read out of it.
Interestingly, we can also see $S_t$ as a compressed (lossy) version of the transformer’s KV-cache: instead of keeping every key-value pair around, we accumulate them all into a fixed-size state.
Example
I thought of an example that can help understand the role of each of the components of the state $S$26.
26 Notice we can quite analogously think of it column-wise instead of row-wise.
We can then see
In
State Space Models (SSMs)
For a much more in-depth analysis of SSMs, check out my post about Mamba models. I’ll focus on Mamba1 and Mamba2, SSMs which are:
- Structured: the state matrix $A$ is forced to take a particular form (diagonal for Mamba1, scalar-times-identity for Mamba2).
- Selective: model parameters are dependent on the input at each time-step.
Mamba1 (S6)
Train-time
I’ll just focus on inference-time interpretability since train-time gets a bit deep and I already explored it in the Mamba post. The TLDR is that they develop a custom CUDA kernel, the so-called scan operation, which efficiently computes the outputs provided the complete input sequence.
Inference-time
Given an input $x_t \in \mathbb{R}^{D}$, each channel maintains a hidden state of size $N$, and the layer uses input-dependent parameters $B_t, C_t \in \mathbb{R}^{N}$ and $\Delta_t \in \mathbb{R}^{D}$27.
27 Note: Usually $N \ll D$.
Remember they use a pre-fixed (learned but input-independent) matrix $A$, constrained to be diagonal.
For the recurrence to be efficiently computable, the parameters are first discretized using the input-dependent step size $\Delta_t$: $\bar{A}_t = \exp(\Delta_t A)$ and $\bar{B}_t = \Delta_t B_t$ (with $A$ diagonal, these are just element-wise operations).
We then apply the recurrence relation:

$$h_t = \bar{A}_t \odot h_{t-1} + \bar{B}_t x_t, \qquad y_t = C_t^\top h_t$$
Computationally:
| Operation | Memory | FLOPs |
|---|---|---|
| $B_t, C_t, \Delta_t$ projections | $O(N + D)$ | $O(DN + D^2)$ |
| Discretize ($\bar{A}_t$, $\bar{B}_t$) | $O(DN)$ | $O(DN)$ |
| Compute $\bar{B}_t x_t$ | $O(DN)$ | $O(DN)$ |
| Compute $h_t = \bar{A}_t \odot h_{t-1} + \bar{B}_t x_t$ | $O(DN)$ | $O(DN)$ |
| Compute $y_t = C_t^\top h_t$ | $O(D)$ | $O(DN)$ |
| Total | $O(DN)$ | $O(DN + D^2)$ |
It is very straightforward in this case 😂. Thinking in terms of gating mechanisms:

- $h_t$ is the internal RNN state.
- $\bar{A}_t$ controls what components of $h_{t-1}$ get forgotten.
- $\bar{B}_t$ controls how $x_t$ gets added to $h_t$.
- $C_t$ controls what components of $h_t$ compose the output.
Mamba2 (SSD)
Mamba2’s SSM layer, presented within the SSD (State Space Duality) framework, introduces two key changes with respect to Mamba1:
- It further restricts the $A$ matrix to be of the scalar-times-identity type28: $A_t = a_t I$.
28 Instead of diagonal.
- It directly works with multi-dimensional input-output pairs29: $x_t, y_t \in \mathbb{R}^{D}$.
29 Instead of scalars.
Train-time
Since we have the whole sequence available at train time, we can aim to compute all the outputs in one go.
Now the problem gets simplified to the point that we can express the input-output mapping directly by a single matrix multiplication! To do so, let’s define a new matrix $M \in \mathbb{R}^{T \times T}$ with entries

$$M_{ij} = C_i^\top \left(\prod_{l=j+1}^{i} a_l\right) B_j \quad \text{for } i \ge j, \qquad M_{ij} = 0 \quad \text{otherwise.}$$
We then have that:

$$y = M x$$
It is actually very simple to see given the recurrence defined by the discretized SSM problem and the constraints:

$$h_t = a_t h_{t-1} + B_t x_t, \qquad y_t = C_t^\top h_t$$
Then:

$$y_t = C_t^\top \sum_{j=1}^{t} \left(\prod_{l=j+1}^{t} a_l\right) B_j x_j = \sum_{j=1}^{t} M_{tj}\, x_j$$
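A small NumPy check that the recurrence and the matrix form agree (scalar decay $a_t$, a single input channel, and random input-dependent $B_t, C_t$ for simplicity):

```python
import numpy as np

T, N = 16, 4
rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, T)        # scalar decay per step (scalar-times-identity A)
B = rng.standard_normal((T, N))     # input-dependent B_t
C = rng.standard_normal((T, N))     # input-dependent C_t
x = rng.standard_normal(T)          # one input channel

# Recurrent form.
h = np.zeros(N)
y_rec = np.zeros(T)
for t in range(T):
    h = a[t] * h + B[t] * x[t]
    y_rec[t] = C[t] @ h

# Matrix form: M[t, j] = (a_{j+1} * ... * a_t) * (C_t . B_j) for j <= t.
M = np.zeros((T, T))
for t in range(T):
    for j in range(t + 1):
        M[t, j] = np.prod(a[j + 1:t + 1]) * (C[t] @ B[j])
y_mat = M @ x

assert np.allclose(y_rec, y_mat)
```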
On its relationship with linear transformers
Hold your 🐎🐎! Didn’t we see something very similar already??
Yes! The linear transformer has almost the same form!! 🤯
The linear transformer is a particular case of the SSD framework where $a_t = 1$ for all $t$ (no decay): then $M_{ij} = C_i^\top B_j$, with $C_i$ playing the role of $\phi(q_i)$ and $B_j$ that of $\phi(k_j)$.
This can also be understood as a generalization of positional encoding, as it is now input-dependent instead of a fixed sinusoidal, rotary (RoPE) or similar scheme. Intuitively: the further away two points are, the more their interaction gets dampened by the cumulative product of decay factors between them.
On memory requirements of Mamba2
Keep your shirt on 👕! Doesn’t this formulation force us to use the quadratic formulation of linear transformers?
Good question, and yes, if we do a naive implementation of it! However, the matrix $M$ is highly structured (semiseparable), which allows a chunked block decomposition: diagonal blocks are computed in the quadratic, attention-like way, while off-diagonal blocks are propagated through the recurrent form. This keeps the overall cost linear in $T$ while making good use of matrix-multiplication hardware.
On its modelling power
Cool your jets 🛩️! Isn’t it counter-productive to restrict $A$ so much?
This is something still being tested30. By having a more restricted $A$, each SSM layer is in principle less expressive; in exchange, training and inference become much more efficient, which allows scaling to larger models and longer sequences.
30 At the time of writing: December 2024.
In theory:
- Mamba2 training >> Mamba1 training
- Mamba1 inference performance > Mamba2 inference performance
In practice:
- Early results indicate that the trade-offs taken in Mamba2 are worth it: it performs on par or better on early benchmarks31.
31 More testing needed.
Inference-time
I’ll not dive into the inference logic of SSD since it is analogous to that of the already-covered linear transformer.
Epilogue
Alright, that got a bit out of hand (unsurprisingly), but it was useful for me to connect some ideas I had about sequence modelling. See you around!
Pasting the narrative here again to wrap things up:
Let’s tackle the problem
- What about $y_t = f(x_t)$? (Feed Forward)
  - Too simplistic: we are not using the fact that the data is ordered in a meaningful way.
- What about $y_t = f(x_{t-k}, \dots, x_t)$? (1D CNN)
  - We’d need to stack many of those layers for context to flow from the beginning to the end of the sequence.
- What about $y_t, h_t = f(x_t, h_{t-1})$? (RNN)
  - In general this is not parallelizable (very slow training), the network forgets information, and we have the vanishing-gradient problem.
- What about making $h_t$ the concatenation of all previously seen $x_i$, i.e. $h_t = (x_1, \dots, x_t)$? (Transformer)
  - This works great but consumes a lot of memory and FLOPs.
- What about a more efficient approximation of the previous one?
  - Nice ideas, but not as good as the softmax transformer.
- Let’s go back to the RNN idea: isn’t there a way to train it more efficiently? (Mamba1)
  - Yes, if we add some structure.
- Can’t we run it faster? (Mamba2)
  - Yes, if we add even more structure. Wait, we re-discovered a more general version of the linear transformer coming from the SSM branch!
The end.