Mamba Explained: The State Space Model taking on Transformers

Right now, AI is eating the world. And by AI, I mean Transformers: practically all of the big breakthroughs in AI over the last few years are due to them. Mamba, however, belongs to an alternative class of models called State Space Models (SSMs). Importantly, for the first time, Mamba promises similar performance (and, crucially, similar scaling laws) to the Transformer whilst remaining feasible at long sequence lengths (say, 1 million tokens). To achieve this long context, the Mamba authors remove the “quadratic bottleneck” in the Attention Mechanism. Mamba also runs fast - like “up to 5x faster than Transformer” fast.

Gu and Dao, the Mamba authors, write: “Mamba enjoys fast inference and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modelling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.”

Here we’ll discuss: The advantages (and disadvantages) of Mamba (🐍) vs Transformers (🤖), Analogies and intuitions for thinking about Mamba, and What Mamba means for Interpretability, AI Safety and Applications.

Problems with Transformers - Maybe Attention Isn’t All You Need

We’re very much in the Transformer era of history. ML used to be about detecting cats and dogs. Now, with Transformers, we’re generating human-like poetry, coding better than the median competitive programmer, and solving the protein folding problem. But Transformers have one core problem: in a Transformer, every token can look back at every previous token when making predictions. To enable this lookback, we cache detailed information about each token in the so-called KV cache.

This pairwise communication means a forward pass takes O(n²) time in training (the dreaded quadratic bottleneck), and each new token generated autoregressively takes O(n) time. In other words, as the context size increases, the model gets slower. To add insult to injury, storing this key-value (KV) cache requires O(n) space. Consequently, the dreaded CUDA out-of-memory (OOM) error becomes a significant threat as the memory footprint expands. If space were the only concern, we might consider adding more GPUs; however, with latency increasing quadratically, simply adding more compute might not be a viable solution.
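To make these asymptotics concrete, here is a minimal numpy sketch of single-head causal attention (illustrative only; real Transformers use fused, batched kernels). The (n, n) score matrix is the quadratic bottleneck:

```python
import numpy as np

def causal_attention(Q, K, V):
    """Naive single-head causal attention.

    The (n, n) score matrix below is the quadratic bottleneck:
    O(n^2) work (and transient memory) in the sequence length n.
    """
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                    # (n, n) pairwise scores
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores[mask] = -np.inf                           # each token sees only the past
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V
```

During autoregressive generation, the keys and values of all previous tokens are kept around in the KV cache, so generating token t reads t cached pairs: O(n) memory in total, and O(n) work per new token.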

Foundation Model Backbones

Fundamentally, all good ML architecture backbones have components for two important operations: communication between tokens and computation within a token. In Transformers, these are Attention (communication) and MLPs (computation). We improve Transformers by optimising these two operations. We would like to replace the Attention component with an alternative mechanism for inter-token communication. Specifically, Mamba employs a Control Theory-inspired State Space Model, or SSM, for communication, whilst retaining Multilayer Perceptron (MLP)-style projections for computation.
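This decomposition can be sketched as a generic backbone layer (a toy schematic with illustrative names and shapes, not either architecture's real layer): an inter-token mixing step followed by a position-wise MLP, each with a residual connection. Swapping out the `communicate` function is exactly the substitution Mamba makes:

```python
import numpy as np

def mlp(x, W_in, W_out):
    """Computation *within* each token: applied position-wise, no mixing."""
    return np.maximum(x @ W_in, 0.0) @ W_out         # (n, d) -> (n, d)

def backbone_block(x, communicate, W_in, W_out):
    """One generic backbone layer: inter-token communication, then
    per-token computation, each wrapped in a residual connection.
    `communicate` is Attention in a Transformer and an SSM in Mamba."""
    x = x + communicate(x)                           # tokens exchange information
    x = x + mlp(x, W_in, W_out)                      # each token computes alone
    return x
```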

Motivating Mamba - A Throwback to Temple Run

Imagine we’re building a Temple Run agent. At each moment, it chooses whether the runner should move left or right. To pick the correct direction, we need information about our surroundings. Let’s call the collection of relevant information the state. Here the state likely includes your current position and velocity, the position of the nearest obstacle, weather conditions, etc.

Claim 1: if you know the current state of the world and how the world is evolving, then you can use this to determine the direction to move. Note that you don’t need to look at the whole screen all the time. You can figure out what will happen to most of the screen by noting that as you run, the obstacles move down the screen. You only need to look at the top of the screen to understand the new information and then simulate the rest.

This lends itself to a natural formulation. Let h be the hidden state, the relevant knowledge about the world. Also let x be the input, the observation that you get each time. h’ then represents the derivative of the hidden state, i.e. how the state is evolving. We’re trying to predict y, the optimal next move (right or left). Now, Claim 1 states that from the hidden state h, h’, and the new observation x, you can figure out y.
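This is precisely the classical linear state space model from control theory: in continuous time, h′(t) = A h(t) + B x(t) and y(t) = C h(t). Here is a minimal discretised sketch of that recurrence (toy matrices and a plain step rule, for illustration only - not Mamba’s actual discretisation or parameterisation):

```python
import numpy as np

def ssm_scan(A, B, C, xs, h0):
    """Discrete linear SSM: h_t = A @ h_{t-1} + B @ x_t,  y_t = C @ h_t.

    The fixed-size hidden state h is the model's entire memory of the
    past: no KV cache, so each step costs O(1) in the sequence length.
    """
    h = h0
    ys = []
    for x in xs:
        h = A @ h + B @ x        # evolve the state with the new observation
        ys.append(C @ h)         # read the prediction out of the state
    return np.array(ys)
```

Note the contrast with attention: however long the game has been running, the agent carries only the fixed-size state h forward, rather than a cache that grows with every frame it has seen.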
