LLM(十一)——Mamba
一、SSM(State Space Model)
1.1 State Space
下图中每个小框代表迷宫中的一个位置,并有某些隐式的信息,例如你距离出口有多远:

而上述迷宫可以简化建模为一个“状态空间表示state space representation”,每一个小框显示:
- 当前所在位置(当前状态Current State)
- 下一步可以前往哪里(未来可能的状态Possible Future States)
- 以及哪些变化会将你带到下一个状态(向右或向左)
而描述状态的变量(在上述示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量state vectors”。

在语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,当前位置的向量(状态向量)可能看起来有点像这样:

1.2 SSM
SSM
是用于描述状态表示并根据某些输入预测其下一个状态可能是什么的模型,在
- 映射输入序列
,比如在迷宫中向左和向下移动 - 到潜在状态表示
,比如距离出口距离和 x/y 坐标 - 并导出预测输出序列
,比如再次向左移动以更快到达出口
然而,SSM不使用离散序列(如向左移动一次),而是将连续序列作为输入并预测输出序列。

SSM 假设动态系统(例如在 3D
空间中移动的物体)可以通过两个方程从其在时间
- RNN的循环结构:
和上面的第一个方程非常类似,都是通过上一个隐藏状态和当前输入综合得到当前的隐藏状态,只是两个权重 换成了 两个系数,且去掉了非线性的激活函数 。 就是存储着之前所有历史信息的浓缩精华(可以通过一系列系数组成的矩阵表示之),以基于 更新下一个时刻的空间状态hidden state
通过求解这些方程,假设可以揭示统计原理,以根据观察到的数据(输入序列和先前状态)预测系统的状态。
1.2.1 状态方程与输出方程
SSM的目标是找到状态表示

这两个方程是SSM的核心。矩阵
状态方程
矩阵

换言之,
输出方程
描述了状态如何转换为输出(通过矩阵
01
1.2.2 SSM架构
上述两个方程可以统一成以下架构:

下面通过逐步拆解,以了解这些矩阵如何影响学习过程。
- 假设我们有一些输入信号
,该信号首先乘以矩阵

- 上面第一步的结果,加上上一个状态与矩阵
相乘(矩阵 描述了所有内部状态如何连接)的结果,用来更新状态state

- 然后,使用矩阵
来将状态转换为输出

- 最后,再利用矩阵
提供从输入到输出的直接信号,这通常也称为跳跃连接skip-connection

- 由于矩阵
类似于跳跃连接,因此在没有跳跃连接的情况下,SSM 通常被视为如下

回到简化视角,现在可以关注只矩阵

这两个方程共同旨在根据观测数据预测系统的状态,且考虑到输入一般都是连续的,因此SSM的主要表示是连续时间表示( continuous-time representation )。

1.3 从SSM到S4
1.3.1 从连续到离散
由于除了连续的输入之外,还会通常碰到离散的输入(如文本序列),不过,就算SSM在离散数据上训练,它仍能学习到底层蕴含的连续信息,因为在SSM眼里,sequence不过是连续信号signal的采样,或者说连续的信号模型是离散的序列模型的概括。

利用零阶保持技术(Zero-order hold technique)处理离散化数据。

- 首先,每次收到离散信号时,都会保留其值,直到收到新的离散信号。此过程会创建 SSM 可以使用的连续信号
- 保持该值的时间由一个新的可学习参数表示,称为步长
,它代表输入的阶段性保持(resolution) - 有了连续的输入信号后,便可以生成连续的输出,并且仅根据输入的时间步长对值进行采样

最终够从连续 SSM 转变为离散SSM,使得不再是函数到函数

在保存时,仍然保存矩阵
1.3.2 循环结构表示:方便快速推理
总之,离散 SSM 允许可以用离散时间步长重新表述问题

在每个时间步,都会涉及到隐藏状态的更新(比如

展开一下
如此,便可以用RNN的结构来处理

然后可以这样展开(其中,

1.3.3 卷积结构表示:方便并行训练
在经典的图像识别任务中,用过滤器(即卷积核kernels)来导出聚合特征,而SSM也可以表示成卷积的形式

由于处理的是文本而不是图像,因此需要一维视角

而用来表示这个“过滤器”的内核源自 SSM 公式

- 与卷积一样,可以使用 SSM 内核来检查每组token并计算输出

- 内核将移动一次以执行下一步的计算

- 最后一步,可以看到内核的完整效果:

至于上图中的y_2是咋计算得到的,利用上面推导出来的
以此内推,可得
换个形式看,是不意味着
由于其中三个离散参数
至此,总结一下,将 SSM 表示为卷积的一个主要好处是它可以像卷积神经网络CNN一样进行并行训练。然而,由于内核大小固定,它们的推理不如 RNN 那样快速

现在可以使用循环 SSM 进行有效推理,并使用卷积 SSM 进行并行训练。
- 作为从输入信号到输出信号的参数化映射,SSMs可以当做是RNN与CNN的结合:即推理用RNN结构,训练用CNN结构

该模型称为线性状态空间层 (Linear State-Space Layer,LSSL)
- 总之,这类模型可以非常高效地计算为递归或卷积,在序列长度上具有线性或近线性缩放
1.3.4 长距离依赖问题的解决之道——HiPPO
如我们之前在循环表示中看到的那样,矩阵

其实,某种意义上,算是矩阵

由于矩阵
怎样才能以保留比较长的memory的方式创建矩阵
可以使用Hippo(Hippo的全称是High-order Polynomial Projection Operator,其对应的论文为:HiPPO: Recurrent Memory with Optimal Polynomial Projections),解决如何在有限的存
HiPPO尝试将当前看到的所有输入信号压缩为系数向量

它使用矩阵

具体表示可以如下图所示

正由于HiPPO 矩阵可以产生一个隐藏状态来记住其历史(从数学上讲,它是通过跟踪Legendre polynomial的系数来实现的,这使得它能够逼近所有以前的历史),使得在被应用于循环表示和卷积表示中时,可以处理远程依赖性
如此,S4的定义就出来了:序列的结构化状态空间——Structured State Space for Sequences,一类可以有效处理长序列的 SSM(S4所对应的论文为:Efficiently Modeling Long Sequences with Structured State Spaces)

且对矩阵
所有 HiPPO 矩阵都具有正态加低秩 (Normal Plus Low-Rank,NPLR)
表示
对于单位
1.4 SSM的问题:矩阵不随输入不同而变化,无法针对输入做针对性推理
1.4.1 SSM的问题
首先,Linear Time Invariance(LTI)规定 SSM中的
- 于 SSM 生成的每个token,矩阵
都是相同的 - 使得SSM无法针对输入做针对性的推理
此外,如下图所示,无论输入

同样,无论输入如何,

这里的不变性特指:推理时不随输入变化而变化,但在训练过程中,矩阵是可以根据需要去做梯度下降而变化的,具体来说,对于SSM和S4模型:
- 首先,对于训练过程:在训练时,模型会接收输入数据,并尝试预测输出。模型的参数(矩阵
的值)在每次迭代中通过梯度下降等优化算法进行调整,以便减少预测误差
这意味着矩阵的值会随着训练的进行而逐渐变化,以更好地适应数据 - 其次,对于推理过程:一旦模型训练完成,进入推理阶段,此时矩阵
的值将固定为训练结束时学习到的值。即在推理时,模型使用这些固定的矩阵来处理新的输入数据并生成预测
即无论是SSM,还是mamba,训练时 参数肯定会变 这点毫无疑问
- 但推理时,SSM不会随着输入的不同 做针对性的推理,即任何输入都是一视同仁,至于参数也不会变
- 但mamba会对输入做选择性推理,虽然推理时本身的参数也不会变,但会对不同的输入给予不同的有区别的对待,比如有的重点关注,有的选择性忽略
虽然Mamba模型在推理时参数本身也不变,但由于其设计中引入的选择性机制,使得模型能够根据输入数据的特点进行有区别的对待,这与SSM模型相比是一个显著的进步。且Mamba这种选择性是通过训练阶段的参数学习来实现的(根据训练阶段学习到的参数对不同的输入给予不同的处理),而不是在推理阶段动态调整参数
1.4.2 如何改进S4
比如 “I want to order a hamburger.”这句
- 如果没有选择性,S4会花费相同的“精力”来处理每个单词:

- 但如果是一个试图对这句话的意图进行分类的模型,它可能会想更多地“关注”order、hamburger,而不是want、to
如下图所示,而通过使模型参数成为输入的函数,模型就可以做到“专注于”输入中对于当前任务更重要的部分,而这正是mamba的创新点之一

凡事也有利有弊,虽然mamba可以“专注于”输入中对于当前任务更重要的部分,但坏处是没法再通过CNN做并行训练了,原因在于:
- 之前计算的卷积核
在S4中,我们可以预先计算该内核、保存,并将其与输入
- 但在Mamba中,这些矩阵会根据输入而变化。因此,我们无法预计算
,也无法使用CNN模式来训练我们的模型。从而下面这个式子 用不上了
说白了,如果想要选择性,得用RNN模式进行训练,而偏偏RNN的训练速度非常慢,所以需要找到一种无需卷积的并行训练方式。
二、Mamba
Mamba(其对应论文为:Mamba: Linear-Time Sequence Modeling with Selective State Spaces,GitHub代码),在语言、音频、DNA序列模态上都实现SOTA,在最受关注的语言任务上,Mamba-3B超越同等规模的Transformer,与两倍大的Transformer匹敌,并且相关代码、预训练模型checkpoint都已开源
简言之,Mamba是一种状态空间模型(SSM),建立在更现代的适用于深度学习的结构化SSM (简称S6)基础上,与经典架构RNN有相似之处。
与先前的研究相比,Mamba主要有三点创新:
- 对输入信息有选择性处理(Selection Mechanism)
- 硬件感知的算法(Hardware-aware Algorithm)
该算法采用“并行扫描算法”而非“卷积”来进行模型的循环计算(使得不用CNN也能并行训练),但为了减少GPU内存层次结构中不同级别之间的IO访问,它没有具体化扩展状态当然,这点也是受到了S5(Simplified State Space Layers for Sequence Modeling)的启发 - 更简单的架构
将SSM架构的设计与transformer的MLP块合并为一个块,来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计
2.1 有选择处理信息
作者认为,序列建模的一个基础问题是把上下文压缩成更小的状态,从这个角度来看
transformer的注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(storing the entire context,也就是KV缓存),直接导致训练和推理消耗算力大
好比,Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制。好比,RNN每次只参考前面固定的字数,写的快是快,但容易忘掉更前面的内容
而SSM的问题在于其中的矩阵
不随输入不同而不同,即无法针对不同的输入针对性的推理

最终,Mamba的解决办法是,相比SSM压缩所有历史记录,mamba设计了一个简单的选择机制,通过“参数化SSM的输入”,让模型对信息有选择性处理,以便关注或忽略特定的输入
这样一来,模型能够过滤掉与问题无关的信息,并且可以长期记住与问题相关的信息。好比,Mamba每次参考前面所有内容的一个概括,越往后写对前面内容概括得越狠,丢掉细节、保留大意
总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:
- 高效的模型必须有一个小的状态(比如RNN或S4)
- 而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)
而Mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的

2.1.1 S4的4个参数的不随输入不同而不同
具体来说,S4 模型由四个参数

且它们不随输入变化(即与输入无关),这些参数控制了以下两个阶段
第一阶段(1a 1b)通过固定公式
和 将“连续参数” 转换为“离散参数” ,其中对 称为离散化规则,且可以使用多种规则来实现这一转换。例如下述方程中定义的零阶保持(ZOH)
第二阶段(2a 2b,和3a 3b),在参数由
变换为 后,模型可以用两种方式计算,即线性递归(2)或全局卷积(3)如之前所说的
- 模型通常使用卷积模式(3)可以进行高效的并行化训练其中整个输入序列提前看到,为何可以做高效的并行化呢,因为该模式能够绕过状态计算,并实现仅包含(B, L, D)的卷积核(3a)
- 并切换到循环模式(2)以高效的自回归推理(其中输入每次只看到一个时间步)
2.1.2 S4中三个矩阵的维度表示、维度变化
再回顾一下,通过之前的讲解,可知

- 但为了对批量大小为
、长度为 ( ,比如类似上文举的例子中, )、具有 个通道的输入序列 进行操作(虽然在之前的示例中,每个token的维度设定的1,比如拿 一个 64 × 64维的矩阵A 去记 10000 × 1维的数字,但实际上,经常会遇到一个token不止一个维度的,比如颜色便有R G B三个通道,即embedding的dimension是 )。 则是输入和输出,和 Transformer 里面一样, 他们的大小是 (batch size x sequence length x embedding dim )

Mamba的处理方式是,给这 D 个 dimension的每个 dimension 都搞一个独立的 SSM,即SSM被独立地应用于每个通道。
- 这就解释了为什么下图中的
三个矩阵的第一个维度是都是

请注意,在这种情况下,每个输入的总隐藏状态具有
2.1.3 Mamba:从S4到S6的算法变化流程
在Mamaba中,作者让

- 从S4到S6的过程中
影响输入的
矩阵、影响状态的 矩阵的大小从原来的 变成了 , 指的是输入向量的维度,比如一个颜色的变量一般有R G B三个维度, 指SSM的隐藏层维度hidden dimension,当然 一般设的比较小,远小于 , 指的是序列长度。


且
的大小由原来的 变成了 意味着对于一个batch里的每个token(总共有 个)都有一个独特的 且每个位置的 矩阵、 矩阵、 都不相同,这意味着对于每个输入token,现在都有独特不同的 矩阵、 矩阵,可以解决内容感知问题
推理时参数本身还是不变,但由于参数是数据依赖的,模型在推理时可以根据输入数据的特点进行有区别的对待,即对不同的输入token应用不同的
维度上的变化具体执行时,是通过
、 、 和 ,其中 是参数化投影到维度 。选择 和 是因为与RNN门控机制有关。虽然
没有进行维度变化,但是通过SSM的离散化操作之后 会经过outer product变成 的张量,算是以一种parameter efficient的方式来达到维度变化的目的
,类似遗忘门
这个量跟RNN里的gating有着深刻的联系, 大则关注,小则忽略。跟 的 的 功 能 类 似 , 较小的步长
会更多地关注当前输入而不是上文会 忽 略 当 前 输 入 , 而 更 多 地 使 用 先 前 的 上 文 , 而 较 大 的 步 长 
如果某个输入比较重要 它的步长就更长些,被重点关注。如果某个输入不太重要它的步长就短,被直接忽略从而对于不同的输入,达到选择性关注或忽略的目标,做到详略得当主次分明。
起到的作用类似于:进RNN的memory。 起到的作用类似于:取RNN的memory
修改

意味着对应这个维度的SSM来说, 在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因
Mamba通过合并输入的序列长度和批量大小来使矩阵
2.2 硬件感知的设计:并行扫描(parallel scan)
由于
为了实现并行化:
- 每个状态比如
都是前一个状态比如 乘以 ,加上当前输入 乘以 的总和,这就叫扫描操作(scan operation),可以使用 for 循环轻松计算,然这种状态之下想并行化是不可能的

- Mamba通过并行扫描(parallel
scan)算法使得最终并行化成为可能,其假设执行操作的顺序与关联属性无关因此,可以分段计算序列并迭代地组合它们,即动态矩阵
和 以及并行扫描算法一起创建选择性扫描算法(selective scan algorithm)

时间复杂度 O(n/t) 中的 t ,通常代表用于执行任务的处理器或计算单元的数量。所以才有,如果一个任务在单核上运行需要 O(n) 时间,则在 t 核上并行运行时,理想情况下可以将时间复杂度降低到O(n/t)
把相关推导再拆解一下,以更一目了然
- 首先,
和 的计算很简单,如下所示
* 其次,
- 最后,
最终包含了之前 、 以及 的信息,只是做了整体的压缩

此外,为了让传统的SSM在现代GPU上也能高效计算,Mamba中也使用了Flash Attention技术
- 简而言之,利用内存的不同层级结构处理SSM的状态,减少高带宽但慢速的HBM内存反复读写这个瓶颈
- 具体而言,就是限制需要从 DRAM 到 SRAM 的次数(通过内核融合kernel fusion来实现),避免一有个结果便从SRAM写入到DRAM,而是待SRAM中有一批结果再集中写入DRAM中,从而降低来回读写的次数
2.3 简化的SSM架构及最终的整体流程
将大多数SSM架构比如H3的基础块,与现代神经网络比如transformer中普遍存在的门控MLP相结合,组成新的Mamba块,重复这个块,与归一化和残差连接结合,便构成了Mamba架构。

关于mamba的整体架构,有两点值得强调下
为何要做线性投影
经过线性投影后,输入嵌入的维度可能会增加,以便让模型能够处理更高维度的特征空间,从而捕获更细致、更复杂的特征。
为什么SSM前面有个卷积
本质是对数据做进一步的预处理,更细节的原因在于:- SSM之前的CNN负责提取局部特征(因其擅长捕捉局部的短距离特征),而SSM则负责处理这些特征并捕捉序列数据中的长期依赖关系,两者算互为补充
- CNN有助于建立token之间的局部上下文关系,从而防止独立的token计算。毕竟如果每个 token 独立计算,那么模型就会丢失序列中 token 之间的上下文信息。通过先进行卷积操作,可以确保在进入 SSM 之前,序列中的每个 token 已经考虑了其邻居 token 的信息。这样,模型就不会单独地处理每个 token,而是在处理时考虑了整个局部上下文
最终在更高速的SRAM内存中执行离散化和递归操作,再将输出写回HBM,具体来说

- 不是在GPU HBM(高带宽内存)中将大小为
的扫描输入进 ,而是直接将SSM参数 从慢速HBM加载到快速SRAM中 - 然后,在SRAM中进行离散化,得到
的 - 接着,在SRAM中进行scan得到
的输出 - 最后,multiply and sum with
,得到 的最终输出写回HBM
2.4 Mamba的应用实例
2.4.1 通过mamba预测下一个token的示例
首先进行线性投影以扩展输入嵌入,然后,在应用选择性 SSM之前先进行卷积(如上节所说,以防止独立的token计算)

其中的“选择性SSM(即Selective SSM)”具有以下属性

- Recurrent SSM通过离散化创建循环SSM
- HiPPO对矩阵A进行初始化A以捕获长程依赖性
- 选择性扫描算法(Selective scan algorithm)选择性压缩信息
- 硬件感知算法(Hardware-aware algorithm)加速计算
最后,包含归一化层和用于选择“预测的token”的softmax
2.4.2 三个任务的对比:coping、selective copying、induction heads
如下图所示,有三个任务

- (左)复制任务的标准版本涉及输入和输出元素之间的固定间距,可以通过线性递归和全局卷积等时不变模型轻松解决
- (右上)选择性复制任务在输入之间具有随机间距,需要使用时变模型,在内容上能够灵活地选择记忆或忽略输入
相当于选择性复制任务通过改变“要记忆的tokens的位置”来改进纯粹的复制任务。它需要内容感知推理,以便能够记住相关的标记(有色),并过滤掉不相关的标记(白色)。 - (右下)归纳头部任务是联想回忆的一个例子,需要根据上下文检索答案,这是LLM关键的能力
其实,归纳头部任务是一种众所周知的机制,据推测可以解释LLMs的大部分上下文学习能力。它需要上下文感知的推理,以便知道何时在适当的上下文中产生正确的输出(黑色)。
三、 Mamba-2
Mamba-2(论文:Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,GitHub代码)。
Mamba-2通过将结构化SSM(S4)与注意力变体连接起来,称之为结构化状态空间对偶 (State Space Duality,SSD),该框架是通过结构化矩阵的抽象而形成的。作者提出了两个广泛的框架来表示序列模型,一个是矩阵变换,另一个是张量收缩,每种都揭示了对偶性的不同视角。技术贡献包括:
- 展示了SSM和半可分矩阵之间的等价性。计算状态空间模型的不同方法可以重新定义为结构化矩阵上的各种矩阵乘法算法。
- 改进了线性注意力理论,并推广到新的结构化掩蔽注意力(SMA)。
- 连接SSM和SMA,证明了它们之间有很大的交集,并且是彼此的对偶。
3.1 概述
Mamba-2 论文的要点是结构化状态空间对偶性(SSD),它指的是几个方面:
- SSD 模型是指特定的独立层,例如注意力或 SSM,可以合并到深度神经网络中
- SSD框架是推理该模型(以及更多理论联系)的通用框架
- SSD算法是一种比以前的SSM更高效地计算SSD层的算法
3.1.1 线性 (SSM) 模式
SSD 从与 Mamba 相同的一组方程开始:
结构化状态空间模型 (SSM)定义了
选择性状态空间模型允许
结构化SSM需要
3.1.2 SSD:标量结构化 SSM
最初的 Mamba(或者更准确地说是其核心“S6”层)正是具有对角结构的选择性
SSM。Mamba-2 的 SSD 层仅做了一个小修改:它将对角线
多头SSMS
方程
多个头可完全独立构造;假设正在使用一个头。请注意,这些头与多头注意力模型中的头的工作方式完全相同,并且在
Mamba-2 中,选择与现代 Transformer 类似的尺寸,例如
一些变化轴包括
上的结构影响其参数形状: 适用于一般(非结构化)SSM 用于对角 SSM(或其他结构,例如对角加低秩[1]) 用于标量 SSM(即 SSD)
- 状态维度
( ) - 头部尺寸
( )
结构化 SSM 还有其他变化轴(例如,时不变性与选择性、SISO 与 MIMO [4]、真实与复杂等)。
3.1.3 二次(注意力)模式
首先,暂时忘掉状态空间模型。给定上面相同的张量和相同的形状
然后,定义以下矩阵
最后,
3.1.4 状态空间对偶性(State Space Duality,SSD)
所谓“对偶性”是指方程
在通用 SSD 框架中,将以两种完全不同的方式展示这种等价性,这两种方式实际上都更加通用,并且都非常具有启发性。
SSD vs. SSM
与以前的 SSM 相比,SSD 与 Mamba 的核心层几乎相同,但在循环
- Mamba-1 (S6) 在
上使用对角线结构,而 Mamba-2 (SSD) 在 上使用标量乘以恒等结构。 - Mamba-1 的头部尺寸为
(即所有通道完全由单独的 SSM 独立控制),而 Mamba-2 使用的头部尺寸为 (类似于 )。
特别是,这可以通过两种方式被视为与权重相关:
- 通过将
的对角结构限制为标量乘以恒等式,递归动态在状态空间的所有 元素之间共享。 - 这些动态也会在给定头部的所有
通道之间共享。
换句话说,单个 SSM 头具有总状态大小
SSD vs. Attention
与标准(自)注意力相比,SSD 也只有两点不同:
- softmax 归一化被删除。
- 乘法应用单独的元素掩码矩阵。
第一个差异可以解释为将模型的有效状态大小从线性减小到恒定,并将其效率从二次提高到线性。第二个区别是
SSD
与标准线性注意力的区别。将掩码视为依赖于输入的相对位置编码的一种方法。由于
这可以解释为基于位置
3.2 SSD框架1:结构化矩阵变换
本节探讨了状态空间模型作为序列变换的不同视角,并概述了此类映射的属性和算法。本节的主要结果是关于状态空间模型与一类称为半可分矩阵的结构化矩阵之间的等价性。
3.2.1 矩阵变换
多序列模型,即序列变换
根据定义,
乘以
其中
画出来,这个矩阵看起来像
3.2.2 半可分离矩阵(Semiseparable Matrices)
如果下三角部分(即对角线上或对角线下)包含的每个子矩阵的秩最多为
定义算子
考虑任何非对角块
每个
3.2.3 1-半可分矩阵:标量 SSM 递归
1-SS 矩阵的重要性在于它们等价于标量递归的最小形式——状态维度
将矩阵乘以
3.2.4 状态空间模型是半可分的矩阵
状态空间模型变换
状态空间模型是半可分矩阵作为序列变换,状态空间模型可以表示为作用于序列维度
3.2.5 通过结构化矩阵算法计算状态空间模型
半可分矩阵(即秩结构矩阵)是一种经典的结构化矩阵:
- 它们具有压缩表示,例如 SSS 形式,其仅具有
而不是 参数。 - 它们具有直接在压缩表示上运行的快速算法。
大小为
张量收缩算法:其中维度
这里,
- 扩展 输入
乘以输入矩阵 - 展开独立标量 SSM 递归 第二步
- 收缩 隐藏状态
乘以输出矩阵
任何状态空间模型,其状态大小为
3.3 SSD框架2:结构化注意力
节的主要结果是基于张量收缩的线性注意力的简单证明,以及在SMA中对结构化掩蔽注意力的广义抽象。
- masked-attention为注意力变体建立了框架,特别关注内核注意力和掩蔽内核注意力。
- linear-attention提供了的第一个主要注意力结果,即通过张量收缩的视角对线性注意力的简单证明。
- structured-attention定义了结构化掩蔽注意力,这是通过结构化矩阵对先前注意力变体的概括。
3.3.1 Kernel Attention
将注意力定义为一个函数
由成对矩阵乘法给出
将
通常的注意力形式
限制
线性化注意力的最常见方法通常被视为矩阵乘法关联性的结果。
3.3.2 (Causal) Linear Attention
然而,一旦基本核注意力稍加修改,就不能再直接使用矩阵乘法的结合性了。
Linear Attention (LA)框架[6]。表明它仍然可以扩展到将因果关系纳入注意力的重要情况,用于自回归设置(例如语言建模)。
更明确地了解它是如何工作的。因果线性注意力的二次形式是
其中
是因果掩模矩阵。
一旦
线性注意力中累积和的出现完全等同于因果掩码
3.3.2.1 线性注意力的张量收缩(Tensor Contraction)证明
用张量收缩或 einsum 表示法非常明确地写出线性注意力的二次形式
通过这个符号,可以注意到这个收缩序列可以写成一个单一的四向收缩
最后,它可以用任何其他收缩排序来计算。特别是,可以对顺序
现在关键的观察是
3.3.3 Structured Masked Attention
从掩蔽注意力的张量收缩角度来看,原始线性注意力的关键在于矩阵向量乘以因果掩码等同于累积和运算符。注意力掩码没有理由必须全部为
结构化掩蔽注意力
(SMA)(或简称为结构化注意力)定义为对查询/键/值
SMA二次模式算法是由
SMA线性模式算法是由
可以将结构化掩蔽注意力实例化为任何给定的矩阵结构类。一些示例包括:
- 线性注意力使用因果掩码。
- RetNet 使用衰减掩码
,其中衰减因子 。 - 对于某些可学习(或输入相关)参数集
,衰减掩码可以推广到 Toeplitz 矩阵 。这可以解释为一种相对位置编码的形式,让人想起 AliBi 等其他方法,但采用乘法而不是加法。 - 另一种变体可以使用傅里叶矩阵
以不同的方式对位置结构进行编码。
在ssd中,考虑半可分离 SMA,它定义了的主要 SSD 模型。
SMA 为任何结构化矩阵
掩盖注意力的双重形式
标准(掩蔽核)注意力机制通常被混为一谈,既是函数又是算法。区分这种区别可以清晰地理解注意力机制的不同变体。
- 将 maskedtention 视为一个特定的function。
- 标准 quadratictention 计算可视为一种计算该函数的algorithm。
- Lineartention是计算相同函数的另一种算法。
此外,在这种情况下
- 掩蔽注意力函数仅仅是四个项的特定收缩。
- 二次和线性注意力算法仅仅是执行收缩的两种不同顺序。
众所周知,收缩顺序会对计算复杂度造成巨大影响,从而导致二次与线性分裂。正如状态空间模型是一种可以通过多种方式计算的变换,具有对偶二次与线性形式,线性注意力具有由两个收缩顺序产生的类似对偶性。
3.4 State Space Duality
在attention中,定义了结构化状态空间模型和结构化注意力,讨论了它们的属性,并表明它们都有二次算法和线性算法。本节将它们联系在一起。主要结果是表明结构化状态空间模型的一个特定情况与结构化注意力的一个特定情况相吻合,并且线性时间 SSM 算法和二次时间核注意力算法是彼此的对偶形式。
3.4.1 标量恒等结构状态空间模型
现在让考虑
这可以矢量化为
其中
其中
因此,通过具体化半可分矩阵
3.4.2 1-半可分离结构化掩蔽注意力
1-SS SMA(具有 1-半可分结构矩阵
1-半可分离结构化注意力是 SMA 最重要的情况,因为它:
- 具有输入相关递归的线性注意力的自然泛化。
- 一般半可分离注意力的最简单情况,相当于有效的自回归注意力。
- 对角状态空间模型的特殊情况。
3.4.3 结构化状态空间对偶 (SSD)
总结一下结果:
- 结构化状态空间模型通常通过线性时间递归定义。
但是,通过扩展表征其线性序列到序列变换的矩阵公式,可以得出二次形式。 - 注意力变体是通过二次时间成对交互定义的模型。但是,通过将其视为四向张量收缩并以不同的顺序减少,可以得出线性形式。
- 每一个的自然特例——
更准确地说,在 矩阵上具有标量恒等结构的状态空间模型,以及在其 掩码上具有 1 半可分离结构的结构化掩码注意力
——是彼此的对偶,具有完全相同的线性和二次形式。
结构化状态空间对偶
状态空间对偶描述了状态空间模型和掩蔽注意力之间的密切关系。(左) 一般 SSM 和 SMA 都具有线性和二次形式,符号上有直接类似物。(右) SSM 和 SMA 在一大类状态空间对偶模型(SSD) 中相交,这些模型将许多序列模型作为特例捕获。
3.5 SSD 模型的硬件高效算法
首先,将矩阵
最简单的例子是,
从这里可以将问题简化为这两个部分。这也可以解释为将“块”
3.5.1 Diagonal Blocks
对角块很容易处理,因为它们只是较小规模的自相似问题。第
这些子问题可以解释为:假设初始状态(对于块而言)为
3.5.2 Low-Rank Blocks
低秩分解由 3 个项组成,因此计算有 3 个部分。在此分解中,将使用术语
- 像
这样的项被称为右因子或 -block-factors。 - 像
这样的项被称为中心因子或 -block-factors。 - 像
这样的项被称为左因子或 -block-factors。
Right Factors.
此步骤计算低秩分解的右
这可以解释为:假设初始状态(对于块)为
enter Factors.
此步骤计算低秩分解中的中心
此步骤可以通过任何计算 1-SS 乘法的算法(也称为标量 SSM 扫描或 运算符)来计算。
这可以解释为:每个块的实际最终状态是什么
考虑到所有先前的输入;换句话说,这将计算真实隐藏状态
Left Factors.
此步骤计算低秩分解的左
这可以解释为:每个块的输出是多少 考虑到正确的初始状态
SSD算法:分块矩阵分解
离散化;可以随心所欲地计算这个乘法;特别是,使用 SSD
的二次(类似注意力)形式。
- (绿色)总共只有
个不同的绿色块,因为其中许多是共享的。这些可以使用批处理 matmul 来计算。 - (黄色)请注意,黄色项本身形成一个
1-半可分离矩阵;换句话说,这一步相当于SSM扫描(在某些修改的
因素上)! - (蓝色)与绿色类似,这些可以使用批处理 matmul 来计算。
SSD 算法:分块和状态传递
该算法的另一种解释涉及推理 SSM
如何对实际序列进行操作。首先将输入序列分割成大小为
1.
块内输出:计算每个块的本地输出(假设(块的)初始状态为
0,每个块的输出是多少?)
2. 块状态:计算每个块的最终状态(假设(块的)初始状态为
0,每个块的最终状态是什么?)
3. 通过状态:计算所有块的最终状态的递归 -
使用任何所需的算法,例如并行或顺序扫描(考虑到所有先前的输入,每个块的实际最终状态是什么?)
4. 输出状态:对于每个块,给定其真实的初始状态(在步骤 3
中计算),计算初始状态对输出的贡献
不管怎样,大多数算法(步骤 1、2 和 4)利用了 matmuls(以及张量核心),并且也可以完全并行计算!只有第 3 步需要扫描,但它的操作序列要短得多,并且通常只需要整个算法时间的一小部分。
3.5.3 Computational Cost
定义符号
- 计算成本:总计
FLOPs。 - 内存成本:总计
空间。 - 并行化:更大的
项可以利用现代加速器上的专门矩阵乘法单元。、 、
1 | |
Center Blocks.
二次 SMA 计算的成本包括三个步骤:
- 计算核矩阵
,其成本为 。 - 乘以掩码矩阵,这是对形状为
的张量的元素运算。 - 乘以
值,其成本为
Low-Rank Blocks: Right Factors.
此步骤是长度为
请注意,由于阻塞将序列长度从
Low-Rank Blocks: Left Factors.
此步骤是单矩阵乘法,成本为
Total Cost.
如果设置
- 总 FLOP 数为
。 - 总内存为
。 - 工作 主要包括形状为
的矩阵的矩阵乘法。
请注意,内存消耗很紧;输入和输出
除了 matmuls 之外,还有一个标量 SSM 扫描,扫描
Comparison to Pure SSM and Attention Models.
二次注意力机制也非常高效地利用了硬件,因为它只利用了矩阵乘法,但总
FLOP 为
线性 SSM 的总 FLOP 为
| Attention | SSM | SSD | |
|---|---|---|---|
| State size | |||
| Training FLOPs | |||
| Inference FLOPs | |||
| (Naive) memory | |||
| Matrix multiplication |
许多其他矩阵分解都是可能的(例如,了解通过不同结构化矩阵分解进行 1-SS 乘法的算法概要)这可能导致更多 SSD 算法,这些算法可能更适合其他专门设置。更广泛地说,除了使用的 SSS 形式之外,半可分矩阵还有丰富的文献和更多的表示形式,甚至可能存在更高效的算法。
3.5.4 Stability
尝试 1:CUMPRODS 比率
第一次天真的尝试可能是注意到该矩阵的条目是累积乘积
然而,这会遇到严重的数字问题,因为这些乘积可能变得非常小(想象一下
修复 1:段求和 ( SEGSUM ) 运算
第二次尝试是在日志空间中完成所有这些,因为所有cumsum代替要处理的 cumprod 。然后,为了计算
1-SS 矩阵,只需计算每个段
尝试 2:CUMSUMS 的差异
再次执行此操作的明显方法是使用与上面相同的想法,但是在日志空间中
沿时间轴计算
1 | |
(然后 1-半可分离矩阵就是该输出的指数)。和/差比乘积/商稳定得多,所以这应该可行——对吧?
修复 2:删除所有减法
不幸的是,事实证明这仍然行不通。这个 1-SS 矩阵的值大致代表了 SSM
动态,它对
尝试 3:稳定 SEGSUM
这引出了参考 SSD 代码中的辅助函数。找到了一种方法来使用一批独立的累加值,无需减法即可立即产生正确的答案,而不是计算单个累加值然后进行减法。
这些细节很重要!如果没有正确实现这些原语,基本 SSD 算法会在训练期间立即生成 NaN(即使使用 FP32)。
离散化
这种结构化状态空间模型的谱系是从 S4 及其前身发展而来的,它们被视为连续时间系统。然而,在 Mamba 中,不再真正认为 SSM 是连续的。事实上,正如原始论文的讨论(第 5 节)中提到的,Mamba 在建模不同类型的数据时与 S4 进行了权衡:
- S4 是一种连续时间模型,擅长对连续数据进行建模,例如感知信号,例如音频波形和像素级视觉。
- Mamba 是一种离散时间模型,擅长对离散数据进行建模,例如标记化数据,例如语言。
然而,Mamba 的参数化仍然使用与先前结构化 SSM
相同的离散化步骤,其中有另一个参数
先前关于结构化 SSM 的理论的初始化和参数化仍然是开箱即用的,那么为什么要修复没有损坏的东西呢?
尽管如此,离散化步骤对于 Mamba 来说并不是真正必要的。在 Mamba-2
论文中,选择直接使用“离散参数”
这不会造成任何问题:要使用连续 SSM 参数化,只需在插入上面的 SSD 代码之前通过上述公式转换参数即可。
在完整的 Mamba-2 代码中,还保留了与 Mamba 相同的参数化和离散化步骤——同样,为什么要修复没有损坏的东西?——但假设“以离散为中心”的变体(例如 LRU和 Griffin 的伽玛归一化)应该同样有效。
### 3.6 Mamba-2 Architecture
通过连接 SSM 和注意力机制,SSD 框架能够为两者开发一个共享的词汇表和技术库。在本节中,将讨论一些使用最初为 Transformers 开发的想法来理解和修改 SSD 层的示例。
Mamba-2 块通过删除连续线性投影简化了 Mamba 块;SSM 参数
3.6.1 Block Design
首先讨论对独立于内部序列混合层(即核心 SSD 层之外)的神经网络块的修改。
Parallel Parameter Projections.
Mamba-1 的动机是 SSM 中心观点其中选择性 SSM 层被视为从
在 Mamba-2 中,SSD 层被视为从
请注意,通过使用标准 Megatron 分片模式,对 SSM 的
Extra Normalization.
在初步实验中,发现较大的模型容易出现不稳定性。能够通过在最终输出投影之前向块添加额外的规范化层(例如 LayerNorm、GroupNorm 或 RMSNorm)来缓解这种情况。这种规范化的使用与 NormFormer 架构最直接相关,该架构还在 MLP 和 MHA 块的末尾添加了规范化层。
还注意到,这种变化类似于从线性注意观点得出的与 Mamba-2 相关的其他最新模型。原始线性注意公式通过分母项进行规范化,该分母项模拟标准注意中 softmax 函数的规范化。TransNormerLLM和 RetNet发现这种规范化是不稳定的,并在线性注意层之后添加了额外的 LayerNorm 或 GroupNorm。额外规范化层与这些略有不同,它出现在乘法门分支之后而不是之前。
3.6.2 Multihead Patterns for Sequence Transformations
回想一下,SSM 被定义为序列变换其中:
参数具有状态维度 。- 它们定义了一个序列变换
,例如可以表示为矩阵 。 - 此变换对输入序列
进行操作,独立于 轴。
可以将其视为定义序列变换的一个head。多头序列变换由
状态大小
Multihead SSM (MHS) / Multihead Attention (MHA) Pattern.
经典的 MHA 模式假设头部维度
Multi-contract SSM (MCS) / Multi-query Attention (MQA) Pattern.
多查询注意力是一种巧妙的注意力优化方法,可以显著提高自回归推理的速度,它依赖于缓存
使用状态空间对偶,可以将 MQA 的等效 SSM 版本定义为方程Multi-contract
SSM。这里,
可以类似地定义一个多键注意 (MKA) 或多扩展 SSM (MES)
头部模式,其中
Multi-input SSM (MIS) / Multi-value Attention (MVA) Pattern.
虽然 MQA 因其 KV 缓存而适合用于注意力机制,但它并不是 SSM
的自然选择。在 Mamba 中,
有了这个词汇表,可以更准确地描述原始的 Mamba 架构。
Mamba 架构的选择性 SSM(S6)层可以被视为具有:
- 头部维度
:每个通道都有独立的 SSM 动态 。 - 多输入 SSM (MIS) 或 多值注意 (MVA)
头部结构:
矩阵(对应注意对偶中的 )在输入 (对应注意中的 )的所有通道之间共享。
当应用于 SSD 时,还可以消除这些头部模式变体。有趣的是,尽管在参数数量和总状态维度方面受到控制,但下游性能仍存在明显差异。通过经验发现,最初在 Mamba 中使用的 MVA 模式表现最佳。
Grouped Head Patterns.
多查询注意力的思想可以扩展到分组查询注意力:除了
类似地,Mamba-2 中使用的多输入 SSM 头模式可以轻松扩展到分组输入 SSM (GIS),或同义地分组值注意力 (GVA)。概括很简单,为了简单起见,省略了细节。
3.6.3 Other SSD Extensions from Linear Attention
将内核特征图的选择视为 Mamba-2 架构中的超参数,并期望其他受注意力机制启发的简单修改也有可能实现。
Kernel Attention Approximations to Softmax Attention.
线性注意力或核注意力的许多变体都是通过将注意力得分
- 指数核
,对于某些核特征图,可以通过 来近似。 - 通过
对核进行归一化,使得行总和为 ,其中除法按元素进行,并且 是全 1 的向量。
在 Mamba-2 中,整合了一个灵活的内核特征图,并将其应用于
Incorporating a Normalization (Denominator) Term.
为了找到分母项,只需计算
请注意,在这种情况下,核特征图
3.7 Systems Optimization for SSMs
描述了针对 SSM 的几种系统优化,特别是 Mamba-2 架构,以实现大规模高效训练和推理。具体来说,专注于大规模训练的张量并行和序列并行,以及高效微调和推理的可变长度序列。
3.7.1 Tensor Parallel
张量并行 (TP)是一种模型并行技术,它将每一层(例如注意力、MLP)拆分为多个加速器(例如 GPU)运行。该技术广泛用于在 GPU 集群上训练大多数大型模型,其中每个节点通常有 4-8 个 GPU,并具有 NVLink 等快速网络。TP 最初是为 Transformer 架构开发的,将其改编成其他架构并不简单。首先展示了将 TP 与 Mamba 架构结合使用的挑战,然后展示了 Mamba-2 架构如何设计来提高 TP 效率。
回想一下 Mamba 架构,它具有单个输入
使用 TP,假设想将计算拆分到 2 个 GPU 上。很容易将输入投影矩阵
使用 Mamba-2,目标是每个块只有一个全归约,类似于 Transformers
中的注意力或 MLP 块。因此,可以直接从
只需要拆分输入投影矩阵和输出投影矩阵,并且只需要在块的末尾进行全归约。这类似于
TP 的注意力和 MLP 层的设计。具体来说,如果 TP 度为 2,会将
在mamba2_parallelism (左) 中说明了使用 Mamba-2 的张量并行。
分割输入投影矩阵
3.7.2 Sequence Parallelism
对于非常长的序列,可能需要沿序列长度维度将输入和激活拆分到不同的 GPU。有两种主要技术:
- 残差和归一化操作的序列并行 (SP):将 TP 中的 all-reduce 分解为 Reduce-scatter 和 All-gather。注意到残差和归一化操作在同一 TP 组中的所有 GPU 上重复执行相同的输入,SP 通过执行以下操作沿序列长度维度拆分激活:Reduce-scatter、残差和归一化,然后执行 All-gather。
由于 Mamba-2 架构使用相同的残差和归一化结构,因此 SP 无需修改即可应用。
- 标记混合操作(注意或
SSM)的序列并行,也称为“上下文并行”(CP)。
注意力机制中的序列并行的难点在于,可以将查询和键拆分成块,但每个查询块都需要与键块交互,从而导致通信带宽是工作器数量的二次方。
使用 SSM,可以以一种简单的方式拆分序列:每个工作器采用初始状态,根据其输入计算 SSM,返回最终状态,并将该最终状态传递给下一个工作器。通信带宽与工作器数量成线性关系。这种分解与 SSD 算法 (ssd-algorithm) 中拆分成块/块的块分解完全相同。在图mamba2_parallelism (右) 中说明了这种上下文并行。
3.7.3 Variable Length
虽然预训练通常对批次使用相同的序列长度,但在微调或推理期间,模型可能需要处理不同长度的不同输入序列。处理这种情况的一种简单方法是将批次中的所有序列右填充到最大长度,但如果序列长度相差很大,这种方法效率会很低。对于transformer,已经开发出复杂的技术来避免填充并在
GPU
之间进行负载平衡,或者将多个序列打包在同一批次中并调整注意力掩码。特别是使用
SSM 和
Mamba,可以通过将整个批次视为一个长序列来处理可变序列长度,并避免在各个序列之间传递状态。这相当于简单地将一个序列末尾的标记
3.8 结果
更快的 SSD 算法允许增加状态维度(与 Mamba-1 中的
将 Mamba 层与注意力层相结合可以比纯 Transformer 或 Mamba 有所改进。在 2.7B 参数和 300B 令牌规模下验证,仅具有 6 个注意力块(和 58 个 SSD 块)的混合模型优于 64 个 SSD 块以及标准 Transformer++ 基线(32 个门控 MLP 和 32 个注意力块)。

还验证了对于相同的状态维度,SSD 算法比 Mamba-1 的选择性扫描算法要快得多,并且在计算上可以更好地扩展到更大的状态维度。让这些张量核心发挥作用是关键!
参考
- 一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba
- A Visual Guide to Mamba and State Space Models
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- Efficiently Modeling Long Sequences with Structured State Spaces
- Mamba原理最通俗介绍火了,一文看懂“Transformer挑战者”两大主要思想!网友:年度最佳解读
- MedAI #41: Efficiently Modeling Long Sequences with Structured State Spaces | Albert Gu
- RWKV:Transformer时代的RNN模型
- 如何理解 Mamba 模型 Selective State Spaces?
- [线性RNN系列] Mamba: S4史诗级升级
- Combining Recurrent, Convolutional, and Continuous-time Models with Linear State Space Layers
- S4: 使用结构化状态空间对长序列进行高效建模
- 挑战 Transformer:全新架构 Mamba 详解
- Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
- State Space Duality (Mamba-2)
- Transformers are SSMs:
Generalized Models and Efficient Algorithms Through Structured State
Space Duality