一句话:Mamba 是 Linear Attention 的"升级版"——同样是 $O(N)$ 复杂度、固定大小隐状态,但通过选择性机制(Selective SSM) 让模型能动态决定"记住什么、忘记什么",效果接近 Transformer,推理速度接近 RNN。
一、从 Linear Attention 到 Mamba:解决什么问题?
回顾 Linear Attention 的递推公式:
$$S_t = S_{t-1} + \phi(k_t) v_t^T$$
问题:衰减是固定的(没有衰减,或者 RetNet 里用固定的 $\gamma$),模型无法根据输入内容动态决定"这个历史信息重不重要"。
类比:
RNN(LSTM):有遗忘门,可以选择性地清除历史
Linear Attention:没有遗忘门,历史信息只增不减,全部堆进 S 矩阵
RetNet:有固定衰减 γ,但 γ 是超参数,不随输入变化
Mamba:衰减因子由输入动态生成,每个 token 的"遗忘程度"不同
二、SSM 的数学基础
Mamba 基于状态空间模型(State Space Model,SSM),这是控制论里的经典框架。
连续时间 SSM
$$h’(t) = A h(t) + B x(t)$$ $$y(t) = C h(t)$$
- $x(t)$:输入信号
- $h(t)$:隐状态(类比 RNN 的 hidden state)
- $y(t)$:输出
- $A$:状态转移矩阵(控制历史信息如何演化)
- $B$:输入投影矩阵(控制输入如何影响隐状态)
- $C$:输出投影矩阵(控制隐状态如何映射到输出)
离散化(实际使用的形式)
连续 SSM 需要离散化才能用于序列建模,使用零阶保持(ZOH) 方法:
$$\bar{A} = e^{\Delta A}, \quad \bar{B} = (e^{\Delta A} - I) A^{-1} B \approx \Delta B$$
其中 $\Delta$(dt_bias 对应的参数)是步长(step size),控制离散化的粒度。
离散化后的递推公式:
$$h_t = \bar{A} h_{t-1} + \bar{B} x_t$$ $$y_t = C h_t$$
这和 RNN 的结构完全一样!
RNN: h_t = tanh(W_h h_{t-1} + W_x x_t)
SSM: h_t = Ā h_{t-1} + B̄ x_t
区别: SSM 的 Ā 有特殊结构(来自连续系统的离散化),更有理论保证
三、S4:Mamba 的前身
S4(Structured State Space Sequence Model,2021) 是 Mamba 的直接前身。
S4 的关键设计:把 $A$ 矩阵限制为对角加低秩(DPLR) 结构,使得:
- 可以高效并行计算(卷积形式)
- 可以高效递推(RNN 形式)
- 理论上能捕获超长距离依赖
S4 的问题:$A$、$B$、$C$ 都是固定参数,不随输入变化——内容无关(content-unaware)。
S4 处理"今天天气很好"和"今天天气很差"时,
用的是完全相同的状态转移矩阵 Ā,
模型无法根据"好"还是"差"来决定记住多少。
四、Mamba 的核心创新:选择性机制
Mamba(2023,Albert Gu & Tri Dao)的核心贡献:让 $B$、$C$、$\Delta$ 依赖于输入 $x_t$。
对比 S4 和 Mamba
| 参数 | S4 | Mamba |
|---|---|---|
| $A$ | 固定(训练后不变) | 固定(但结构特殊) |
| $B$ | 固定 | 由 $x_t$ 动态生成 |
| $C$ | 固定 | 由 $x_t$ 动态生成 |
| $\Delta$ | 固定 | 由 $x_t$ 动态生成 |
选择性机制的直觉
$$B_t = \text{Linear}_B(x_t), \quad C_t = \text{Linear}C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}\Delta(x_t))$$
$$\bar{A}_t = e^{\Delta_t A}, \quad \bar{B}_t = \Delta_t B_t$$
$$h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$$ $$y_t = C_t h_t$$
$\Delta_t$ 的作用(这就是 config.json 里的 dt_bias):
Δ_t 很大 → Ā_t ≈ 0,B̄_t ≈ B_t
→ h_t ≈ B_t x_t(几乎忘掉历史,专注当前输入)
→ 相当于"重置"隐状态
Δ_t 很小 → Ā_t ≈ I,B̄_t ≈ 0
→ h_t ≈ h_{t-1}(几乎忽略当前输入,保留历史)
→ 相当于"跳过"当前 token
这就是 Mamba 的选择性:模型学会了对重要 token 用大 $\Delta$(重置并记住),对不重要 token 用小 $\Delta$(直接跳过)。
五、A 矩阵:HiPPO 初始化
$A$ 矩阵虽然固定,但初始化方式很关键。Mamba 使用 HiPPO(High-order Polynomial Projection Operators) 初始化:
$$A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & n > k \ n+1 & n = k \ 0 & n < k \end{cases}$$
直觉:HiPPO 矩阵被设计为能最优地压缩历史信息——隐状态 $h$ 相当于对历史输入的多项式近似系数,理论上能记住任意长距离的依赖。
这就是 config.json 里 linear_attn.A_log 的来源:$A$ 以对数形式存储(A_log = log(-A)),保证离散化后 $\bar{A}$ 的特征值在单位圆内(系统稳定)。
六、conv1d 的作用
config.json 里还有 linear_attn.conv1d,这是 Mamba 的另一个设计:
在 SSM 之前,先做一个短程卷积:
x_t → conv1d(kernel_size=4)→ SSM → y_t
为什么需要 conv1d?
SSM 的隐状态是全局的(压缩了所有历史),但对局部特征(如 n-gram、短语结构)不敏感。conv1d 用小卷积核(kernel_size=4,对应 config 里的 linear_conv_kernel_dim: 4)捕获局部模式,作为 SSM 的补充。
conv1d:捕获局部特征(短程)
SSM: 捕获全局依赖(长程)
两者结合:覆盖所有尺度的依赖
七、Mamba Block 的完整结构
输入 x [seq_len, d_model]
│
├─────────────────────────────┐
│ │
Linear(d_model → d_inner) Linear(d_model → d_inner)
│ │
SiLU │
│ │
conv1d(kernel=4) │
│ │
SiLU │
│ │
SSM(选择性状态空间) │
┌─────────────────────┐ │
│ B_t = Linear_B(x_t) │ │
│ C_t = Linear_C(x_t) │ │
│ Δ_t = Linear_Δ(x_t) │ │
│ h_t = Ā_t h_{t-1} + B̄_t x_t │ │
│ y_t = C_t h_t │ │
└─────────────────────┘ │
│ │
└──────── × ──────────────────┘
│(门控:SSM输出 × 线性分支)
│
Linear(d_inner → d_model)
│
输出 y
门控设计:右边的线性分支类似 SwiGLU 里的门控,让模型能选择性地"放大"或"抑制" SSM 的输出。
八、完整代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelectiveSSM(nn.Module):
"""
Mamba 的核心:选择性状态空间模型(Selective SSM)。
d_model: 输入维度
d_state: 隐状态维度(论文中通常为 16)
"""
def __init__(self, d_model: int, d_state: int = 16):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# A 矩阵:以对数形式存储,保证离散化后稳定
# 初始化为 log(1, 2, ..., d_state),近似 HiPPO
A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(d_model, -1)
self.A_log = nn.Parameter(torch.log(A)) # [d_model, d_state]
# B、C、Δ 由输入动态生成
self.linear_B = nn.Linear(d_model, d_state, bias=False)
self.linear_C = nn.Linear(d_model, d_state, bias=False)
self.linear_delta = nn.Linear(d_model, d_model, bias=True) # dt_bias 在这里
# D:跳跃连接(直接将输入加到输出)
self.D = nn.Parameter(torch.ones(d_model))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [batch, seq_len, d_model]
返回: [batch, seq_len, d_model]
"""
batch_size, seq_len, d_model = x.shape
d_state = self.d_state
# 恢复 A(负数,保证系统稳定)
A = -torch.exp(self.A_log) # [d_model, d_state]
# 动态生成 B、C、Δ
B = self.linear_B(x) # [batch, seq_len, d_state]
C = self.linear_C(x) # [batch, seq_len, d_state]
delta = F.softplus(self.linear_delta(x)) # [batch, seq_len, d_model],保证 Δ > 0
# 离散化(ZOH,零阶保持):
# Ā = exp(Δ·A)
# B̄ = (ΔA)⁻¹ · (exp(ΔA) - I) · ΔB
# 因为 A 是对角矩阵(每个 d_model 维度独立对应 d_state 个状态),
# (ΔA)⁻¹ · (exp(ΔA) - I) 可以逐元素计算,无需矩阵求逆:
# B̄ = (exp(ΔA) - 1) / A · B = (Ā - 1) / A · B
#
# delta: [batch, seq_len, d_model]
# A: [d_model, d_state]
# delta_A: [batch, seq_len, d_model, d_state]
delta_A_product = delta.unsqueeze(-1) * A # ΔA,[batch, seq_len, d_model, d_state]
delta_A = torch.exp(delta_A_product) # Ā = exp(ΔA),[batch, seq_len, d_model, d_state]
# ZOH 的 B̄:逐元素 (exp(ΔA) - 1) / A · B
# 当 A 接近 0 时用 Taylor 展开近似(数值稳定),但实际中 A 初始化为负整数,不会为 0
delta_B_zoh = (delta_A - 1.0) / A # (exp(ΔA) - 1) / A,[batch, seq_len, d_model, d_state]
delta_B = delta_B_zoh * B.unsqueeze(2) # B̄ = delta_B_zoh · B,[batch, seq_len, d_model, d_state]
# 递推计算隐状态
# h: [batch, d_model, d_state]
h = torch.zeros(batch_size, d_model, d_state, device=x.device)
outputs = []
for t in range(seq_len):
# h_t = Ā_t ⊙ h_{t-1} + B̄_t ⊙ x_t
h = delta_A[:, t] * h + delta_B[:, t] * x[:, t].unsqueeze(-1)
# y_t = C_t · h_t(对 d_state 维度求和)
y_t = (C[:, t].unsqueeze(1) * h).sum(dim=-1) # [batch, d_model]
outputs.append(y_t)
y = torch.stack(outputs, dim=1) # [batch, seq_len, d_model]
# 跳跃连接:y += D * x
y = y + self.D * x
return y
class MambaBlock(nn.Module):
"""
完整的 Mamba Block,包含门控结构和 conv1d。
d_model: 模型维度
d_inner: 内部扩展维度(通常为 d_model * 2)
d_state: SSM 隐状态维度
conv_kernel: 局部卷积核大小
"""
def __init__(self, d_model: int, d_inner: int = None, d_state: int = 16, conv_kernel: int = 4):
super().__init__()
self.d_model = d_model
self.d_inner = d_inner or d_model * 2
self.norm = nn.LayerNorm(d_model)
# 输入投影:分成两路(SSM 路 + 门控路)
self.input_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
# 局部卷积(捕获短程特征)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=conv_kernel,
padding=conv_kernel - 1, # causal padding
groups=self.d_inner, # depthwise conv
bias=True
)
# 选择性 SSM
self.ssm = SelectiveSSM(self.d_inner, d_state)
# 输出投影
self.output_proj = nn.Linear(self.d_inner, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [batch, seq_len, d_model]
返回: [batch, seq_len, d_model](残差连接在外部处理)
"""
residual = x
x = self.norm(x)
# 分成两路
projected = self.input_proj(x) # [batch, seq_len, d_inner*2]
ssm_branch, gate_branch = projected.chunk(2, dim=-1) # 各 [batch, seq_len, d_inner]
# SSM 路:conv1d → SiLU → SSM
# conv1d 需要 [batch, channels, seq_len] 格式
ssm_branch = ssm_branch.transpose(1, 2) # [batch, d_inner, seq_len]
ssm_branch = self.conv1d(ssm_branch)[..., :x.shape[1]] # causal: 截掉多余的 padding
ssm_branch = ssm_branch.transpose(1, 2) # [batch, seq_len, d_inner]
ssm_branch = F.silu(ssm_branch)
ssm_branch = self.ssm(ssm_branch) # [batch, seq_len, d_inner]
# 门控路:SiLU
gate_branch = F.silu(gate_branch)
# 门控融合
output = ssm_branch * gate_branch # [batch, seq_len, d_inner]
# 输出投影 + 残差
output = self.output_proj(output) # [batch, seq_len, d_model]
return output + residual
# ── 验证 ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
torch.manual_seed(42)
batch_size, seq_len, d_model = 2, 16, 64
block = MambaBlock(d_model=d_model, d_inner=128, d_state=16, conv_kernel=4)
x = torch.randn(batch_size, seq_len, d_model)
y = block(x)
print(f"输入形状: {x.shape}") # [2, 16, 64]
print(f"输出形状: {y.shape}") # [2, 16, 64]
print(f"无 NaN: {not torch.isnan(y).any().item()}")
# 验证推理时的递推(单步生成)
ssm = SelectiveSSM(d_model=64, d_state=16)
single_token = torch.randn(1, 1, 64)
single_out = ssm(single_token)
print(f"单步推理: {single_out.shape}") # [1, 1, 64]
九、训练时的并行化:Parallel Scan
递推形式在训练时是串行的,效率低。Mamba 的另一个贡献:当 $\bar{A}$ 不依赖输入时(S4),可以用卷积并行计算。
对于固定的 $\bar{A}$,展开递推:
$$h_t = \bar{A}^t h_0 + \sum_{i=0}^{t} \bar{A}^{t-i} \bar{B} x_i$$
$$y_t = C h_t = \sum_{i=0}^{t} \underbrace{C \bar{A}^{t-i} \bar{B}}{\text{卷积核 } K{t-i}} x_i$$
这就是一个因果卷积!可以用 FFT 在 $O(N \log N)$ 时间内并行计算。
Mamba 的选择性机制打破了这个并行化(因为 $\bar{A}_t$ 依赖输入,不再固定),所以 Mamba 使用了 Parallel Scan(并行前缀扫描) 算法。
Parallel Scan 的核心思想
SSM 的递推 $h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$ 是一个结合性(associative) 操作,可以用分治法并行化:
把相邻两步的递推合并成一步:
$$\begin{pmatrix} h_t \ 1 \end{pmatrix} = \begin{pmatrix} \bar{A}_t & \bar{B}t x_t \ 0 & 1 \end{pmatrix} \begin{pmatrix} h{t-1} \ 1 \end{pmatrix}$$
令 $e_t = (\bar{A}_t,\ \bar{B}_t x_t)$ 为第 $t$ 步的"元素",定义结合操作 $\oplus$:
$$(a_2, b_2) \oplus (a_1, b_1) = (a_2 \cdot a_1,\ a_2 \cdot b_1 + b_2)$$
则 $h_t$ 等价于对 $e_1, e_2, \ldots, e_t$ 做前缀扫描(prefix scan)。前缀扫描可以用 $O(\log N)$ 轮并行归约完成,总复杂度 $O(N \log N)$,在 GPU 上每轮内部完全并行。
def parallel_scan_ssm(delta_A: torch.Tensor, delta_B_x: torch.Tensor) -> torch.Tensor:
"""
用 Parallel Scan(并行前缀扫描)计算 SSM 的所有隐状态。
核心思想:把 h_t = A_t * h_{t-1} + b_t 的串行递推,
转化为对 (A_t, b_t) 对的结合性前缀扫描,实现并行计算。
delta_A: [batch, seq_len, d_model, d_state],离散化后的 Ā_t
delta_B_x: [batch, seq_len, d_model, d_state],B̄_t * x_t(已乘以输入)
返回: [batch, seq_len, d_model, d_state],所有时刻的隐状态 h_t
"""
batch_size, seq_len, d_model, d_state = delta_A.shape
# 每个时刻的"元素"是 (a_t, b_t) 对
# a_t = Ā_t(状态转移系数),b_t = B̄_t * x_t(输入贡献)
scan_a = delta_A # [batch, seq_len, d_model, d_state]
scan_b = delta_B_x # [batch, seq_len, d_model, d_state]
# 结合操作:(a2, b2) ⊕ (a1, b1) = (a2*a1, a2*b1 + b2)
# 含义:先经历 (a1, b1) 的转移,再经历 (a2, b2) 的转移
# 前缀扫描:用 log2(seq_len) 轮 up-sweep 完成
# 每轮将步长翻倍,并行合并相邻元素对
# 为了在纯 PyTorch 中演示,这里实现 Blelloch 并行前缀扫描
# 实际 Mamba 用 CUDA kernel 实现,效率更高
num_rounds = int(math.ceil(math.log2(seq_len))) if seq_len > 1 else 0
# 用列表存储每轮的中间结果(实际 CUDA 实现在原地操作)
current_a = scan_a.clone()
current_b = scan_b.clone()
# Up-sweep(归约阶段):步长从 1 倍增到 seq_len/2
stride = 1
for _ in range(num_rounds):
# 找到需要合并的位置对:(i - stride, i),i 从 stride 开始,步长 2*stride
left_indices = torch.arange(0, seq_len - stride, 2 * stride, device=delta_A.device)
right_indices = left_indices + stride
if right_indices.numel() == 0:
break
left_a = current_a[:, left_indices] # [batch, n_pairs, d_model, d_state]
left_b = current_b[:, left_indices]
right_a = current_a[:, right_indices]
right_b = current_b[:, right_indices]
# 结合操作:right ⊕ left
merged_a = right_a * left_a
merged_b = right_a * left_b + right_b
current_a[:, right_indices] = merged_a
current_b[:, right_indices] = merged_b
stride *= 2
# 注意:上面的 up-sweep 只得到了部分前缀结果(类似 Blelloch scan 的归约树)
# 完整的 Blelloch scan 还需要 down-sweep 阶段。
# 在实际 Mamba 实现中,使用专门的 CUDA kernel(mamba_ssm 库中的 selective_scan_cuda)
# 直接在 GPU 上高效完成,避免了 Python 层的循环开销。
# 这里为了说明原理,退回到串行递推作为等价的正确实现:
hidden_states = []
h = torch.zeros(batch_size, d_model, d_state, device=delta_A.device, dtype=delta_A.dtype)
for t in range(seq_len):
h = delta_A[:, t] * h + delta_B_x[:, t]
hidden_states.append(h)
return torch.stack(hidden_states, dim=1) # [batch, seq_len, d_model, d_state]
# 验证 parallel_scan_ssm 和 SelectiveSSM 的递推结果一致
def verify_parallel_scan():
torch.manual_seed(0)
batch_size, seq_len, d_model, d_state = 2, 8, 16, 4
delta_A = torch.rand(batch_size, seq_len, d_model, d_state) * 0.9 + 0.05 # (0.05, 0.95)
delta_B_x = torch.randn(batch_size, seq_len, d_model, d_state) * 0.1
h_scan = parallel_scan_ssm(delta_A, delta_B_x)
# 串行递推作为参考
h_ref = torch.zeros(batch_size, d_model, d_state)
h_ref_list = []
for t in range(seq_len):
h_ref = delta_A[:, t] * h_ref + delta_B_x[:, t]
h_ref_list.append(h_ref.clone())
h_ref_stack = torch.stack(h_ref_list, dim=1)
max_diff = (h_scan - h_ref_stack).abs().max().item()
print(f"Parallel Scan 与串行递推最大误差: {max_diff:.2e}") # 应接近 0
assert max_diff < 1e-5, f"结果不一致!误差 {max_diff}"
print("验证通过 ✓")
verify_parallel_scan()
十、Mamba 和 Transformer 的对比
| Transformer | Mamba | |
|---|---|---|
| 核心操作 | Softmax Attention | 选择性 SSM |
| 训练复杂度 | $O(N^2 d)$ | $O(N d)$(Parallel Scan) |
| 推理复杂度(逐步) | $O(Nd)$(有 KV Cache) | $O(d^2)$(固定隐状态) |
| 推理显存 | $O(Nd)$(随序列增长) | $O(d \cdot d_{state})$(固定) |
| 长序列能力 | 受限于 $O(N^2)$ | 天然支持超长序列 |
| 内容感知 | ✅ Attention 天然内容感知 | ✅ 选择性机制实现内容感知 |
| 精确历史访问 | ✅ 能精确 attend 任意历史 token | ❌ 历史被压缩进固定隐状态 |
| 实现复杂度 | 简单 | 纯 PyTorch 可运行;达到生产级速度需要 CUDA kernel(可直接用 mamba-ssm 库) |
十一、回到 Qwen3.6:config.json 里的参数对应
现在你能完全读懂 Qwen3.6 的 linear_attn 参数了:
"linear_attn.A_log" → SSM 的 A 矩阵(对数形式存储,保证稳定性)
"linear_attn.conv1d" → 局部卷积(kernel_size=4,捕获短程特征)
"linear_attn.dt_bias" → Δ(步长)的偏置项,控制离散化粒度
"linear_attn.in_proj_a" → SSM 的 A/dt 相关投影(用于生成 Δ,控制离散化步长)
"linear_attn.in_proj_b" → SSM 的 B 矩阵投影(输入 x 到隐状态 h 的映射)
"linear_attn.in_proj_ba" → B 和 A/dt 的联合权重(合并存储提高访存效率,推理时拆分使用)
"linear_attn.norm" → SSM 内部的归一化层
"linear_conv_kernel_dim": 4 → conv1d 的 kernel_size
"linear_key_head_dim": 128 → 隐状态维度(d_state 的变体)
"linear_num_key_heads": 16 → 多头 SSM 的头数
"linear_num_value_heads": 32 → 输出头数(可以和 key 头数不同)
Qwen3.6 的 linear_attn 是多头 SSM(Multi-head SSM),每个头独立维护一个隐状态,类似 MHA 里每个头独立做 Attention。
十二、核心要点速查
| 问题 | 答案 |
|---|---|
| Mamba 解决什么问题? | Linear Attention 的"遗忘"问题:让衰减因子随输入动态变化 |
| SSM 的递推公式? | $h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$,$y_t = C_t h_t$ |
| $\Delta$(dt)的作用? | 控制"记住"还是"忘记":大 $\Delta$ 重置历史,小 $\Delta$ 跳过当前 |
| A_log 为什么用对数? | 保证 $A < 0$,离散化后 $\bar{A} = e^{\Delta A} \in (0,1)$,系统稳定 |
| conv1d 的作用? | 捕获局部短程特征,补充 SSM 的全局长程建模 |
| 训练时如何并行? | Parallel Scan(并行前缀扫描),$O(N \log N)$ |
| 推理时的优势? | 隐状态固定大小 $O(d \cdot d_{state})$,不随序列增长 |
| 和 Transformer 最大的差距? | 无法精确访问历史 token,只能访问被压缩的隐状态摘要 |
| Qwen3.6 里怎么用的? | 多头 SSM,每 4 层混合一个 Full Attention 补偿精度损失 |