一句话:Mamba 是 Linear Attention 的"升级版"——同样是 $O(N)$ 复杂度、固定大小隐状态,但通过选择性机制(Selective SSM) 让模型能动态决定"记住什么、忘记什么",效果接近 Transformer,推理速度接近 RNN。


一、从 Linear Attention 到 Mamba:解决什么问题?

回顾 Linear Attention 的递推公式:

$$S_t = S_{t-1} + \phi(k_t) v_t^T$$

问题:衰减是固定的(没有衰减,或者 RetNet 里用固定的 $\gamma$),模型无法根据输入内容动态决定"这个历史信息重不重要"。

类比

RNN(LSTM):有遗忘门,可以选择性地清除历史
Linear Attention:没有遗忘门,历史信息只增不减,全部堆进 S 矩阵
RetNet:有固定衰减 γ,但 γ 是超参数,不随输入变化
Mamba:衰减因子由输入动态生成,每个 token 的"遗忘程度"不同

二、SSM 的数学基础

Mamba 基于状态空间模型(State Space Model,SSM),这是控制论里的经典框架。

连续时间 SSM

$$h’(t) = A h(t) + B x(t)$$ $$y(t) = C h(t)$$

  • $x(t)$:输入信号
  • $h(t)$:隐状态(类比 RNN 的 hidden state)
  • $y(t)$:输出
  • $A$:状态转移矩阵(控制历史信息如何演化)
  • $B$:输入投影矩阵(控制输入如何影响隐状态)
  • $C$:输出投影矩阵(控制隐状态如何映射到输出)

离散化(实际使用的形式)

连续 SSM 需要离散化才能用于序列建模,使用零阶保持(ZOH) 方法:

$$\bar{A} = e^{\Delta A}, \quad \bar{B} = (e^{\Delta A} - I) A^{-1} B \approx \Delta B$$

其中 $\Delta$(dt_bias 对应的参数)是步长(step size),控制离散化的粒度。

离散化后的递推公式:

$$h_t = \bar{A} h_{t-1} + \bar{B} x_t$$ $$y_t = C h_t$$

这和 RNN 的结构完全一样!

RNN:   h_t = tanh(W_h h_{t-1} + W_x x_t)
SSM:   h_t = Ā h_{t-1} + B̄ x_t
区别:  SSM 的 Ā 有特殊结构(来自连续系统的离散化),更有理论保证

三、S4:Mamba 的前身

S4(Structured State Space Sequence Model,2021) 是 Mamba 的直接前身。

S4 的关键设计:把 $A$ 矩阵限制为对角加低秩(DPLR) 结构,使得:

  1. 可以高效并行计算(卷积形式)
  2. 可以高效递推(RNN 形式)
  3. 理论上能捕获超长距离依赖

S4 的问题:$A$、$B$、$C$ 都是固定参数,不随输入变化——内容无关(content-unaware)

S4 处理"今天天气很好"和"今天天气很差"时,
用的是完全相同的状态转移矩阵 Ā,
模型无法根据"好"还是"差"来决定记住多少。

四、Mamba 的核心创新:选择性机制

Mamba(2023,Albert Gu & Tri Dao)的核心贡献:让 $B$、$C$、$\Delta$ 依赖于输入 $x_t$

对比 S4 和 Mamba

参数 S4 Mamba
$A$ 固定(训练后不变) 固定(但结构特殊)
$B$ 固定 由 $x_t$ 动态生成
$C$ 固定 由 $x_t$ 动态生成
$\Delta$ 固定 由 $x_t$ 动态生成

选择性机制的直觉

$$B_t = \text{Linear}_B(x_t), \quad C_t = \text{Linear}C(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}\Delta(x_t))$$

$$\bar{A}_t = e^{\Delta_t A}, \quad \bar{B}_t = \Delta_t B_t$$

$$h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$$ $$y_t = C_t h_t$$

$\Delta_t$ 的作用(这就是 config.json 里的 dt_bias):

Δ_t 很大 → Ā_t ≈ 0,B̄_t ≈ B_t
         → h_t ≈ B_t x_t(几乎忘掉历史,专注当前输入)
         → 相当于"重置"隐状态

Δ_t 很小 → Ā_t ≈ I,B̄_t ≈ 0
         → h_t ≈ h_{t-1}(几乎忽略当前输入,保留历史)
         → 相当于"跳过"当前 token

这就是 Mamba 的选择性:模型学会了对重要 token 用大 $\Delta$(重置并记住),对不重要 token 用小 $\Delta$(直接跳过)。


五、A 矩阵:HiPPO 初始化

$A$ 矩阵虽然固定,但初始化方式很关键。Mamba 使用 HiPPO(High-order Polynomial Projection Operators) 初始化:

$$A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & n > k \ n+1 & n = k \ 0 & n < k \end{cases}$$

直觉:HiPPO 矩阵被设计为能最优地压缩历史信息——隐状态 $h$ 相当于对历史输入的多项式近似系数,理论上能记住任意长距离的依赖。

这就是 config.json 里 linear_attn.A_log 的来源:$A$ 以对数形式存储(A_log = log(-A)),保证离散化后 $\bar{A}$ 的特征值在单位圆内(系统稳定)。


六、conv1d 的作用

config.json 里还有 linear_attn.conv1d,这是 Mamba 的另一个设计:

在 SSM 之前,先做一个短程卷积

x_t → conv1d(kernel_size=4)→ SSM → y_t

为什么需要 conv1d?

SSM 的隐状态是全局的(压缩了所有历史),但对局部特征(如 n-gram、短语结构)不敏感。conv1d 用小卷积核(kernel_size=4,对应 config 里的 linear_conv_kernel_dim: 4)捕获局部模式,作为 SSM 的补充。

conv1d:捕获局部特征(短程)
SSM:   捕获全局依赖(长程)
两者结合:覆盖所有尺度的依赖

七、Mamba Block 的完整结构

输入 x [seq_len, d_model]
        │
        ├─────────────────────────────┐
        │                             │
   Linear(d_model → d_inner)    Linear(d_model → d_inner)
        │                             │
      SiLU                            │
        │                             │
   conv1d(kernel=4)                   │
        │                             │
      SiLU                            │
        │                             │
   SSM(选择性状态空间)               │
   ┌─────────────────────┐            │
   │ B_t = Linear_B(x_t) │            │
   │ C_t = Linear_C(x_t) │            │
   │ Δ_t = Linear_Δ(x_t) │            │
   │ h_t = Ā_t h_{t-1} + B̄_t x_t │   │
   │ y_t = C_t h_t       │            │
   └─────────────────────┘            │
        │                             │
        └──────── × ──────────────────┘
                  │(门控:SSM输出 × 线性分支)
                  │
           Linear(d_inner → d_model)
                  │
               输出 y

门控设计:右边的线性分支类似 SwiGLU 里的门控,让模型能选择性地"放大"或"抑制" SSM 的输出。


八、完整代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelectiveSSM(nn.Module):
    """
    Mamba 的核心:选择性状态空间模型(Selective SSM)。
    d_model: 输入维度
    d_state: 隐状态维度(论文中通常为 16)
    """
    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # A 矩阵:以对数形式存储,保证离散化后稳定
        # 初始化为 log(1, 2, ..., d_state),近似 HiPPO
        A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(d_model, -1)
        self.A_log = nn.Parameter(torch.log(A))  # [d_model, d_state]

        # B、C、Δ 由输入动态生成
        self.linear_B = nn.Linear(d_model, d_state, bias=False)
        self.linear_C = nn.Linear(d_model, d_state, bias=False)
        self.linear_delta = nn.Linear(d_model, d_model, bias=True)  # dt_bias 在这里

        # D:跳跃连接(直接将输入加到输出)
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch, seq_len, d_model]
        返回: [batch, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.shape
        d_state = self.d_state

        # 恢复 A(负数,保证系统稳定)
        A = -torch.exp(self.A_log)  # [d_model, d_state]

        # 动态生成 B、C、Δ
        B = self.linear_B(x)                          # [batch, seq_len, d_state]
        C = self.linear_C(x)                          # [batch, seq_len, d_state]
        delta = F.softplus(self.linear_delta(x))      # [batch, seq_len, d_model],保证 Δ > 0

        # 离散化(ZOH,零阶保持):
        #   Ā = exp(Δ·A)
        #   B̄ = (ΔA)⁻¹ · (exp(ΔA) - I) · ΔB
        # 因为 A 是对角矩阵(每个 d_model 维度独立对应 d_state 个状态),
        # (ΔA)⁻¹ · (exp(ΔA) - I) 可以逐元素计算,无需矩阵求逆:
        #   B̄ = (exp(ΔA) - 1) / A · B = (Ā - 1) / A · B
        #
        # delta:   [batch, seq_len, d_model]
        # A:       [d_model, d_state]
        # delta_A: [batch, seq_len, d_model, d_state]
        delta_A_product = delta.unsqueeze(-1) * A          # ΔA,[batch, seq_len, d_model, d_state]
        delta_A = torch.exp(delta_A_product)               # Ā = exp(ΔA),[batch, seq_len, d_model, d_state]

        # ZOH 的 B̄:逐元素 (exp(ΔA) - 1) / A · B
        # 当 A 接近 0 时用 Taylor 展开近似(数值稳定),但实际中 A 初始化为负整数,不会为 0
        delta_B_zoh = (delta_A - 1.0) / A                 # (exp(ΔA) - 1) / A,[batch, seq_len, d_model, d_state]
        delta_B = delta_B_zoh * B.unsqueeze(2)             # B̄ = delta_B_zoh · B,[batch, seq_len, d_model, d_state]

        # 递推计算隐状态
        # h: [batch, d_model, d_state]
        h = torch.zeros(batch_size, d_model, d_state, device=x.device)
        outputs = []

        for t in range(seq_len):
            # h_t = Ā_t ⊙ h_{t-1} + B̄_t ⊙ x_t
            h = delta_A[:, t] * h + delta_B[:, t] * x[:, t].unsqueeze(-1)
            # y_t = C_t · h_t(对 d_state 维度求和)
            y_t = (C[:, t].unsqueeze(1) * h).sum(dim=-1)  # [batch, d_model]
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # [batch, seq_len, d_model]

        # 跳跃连接:y += D * x
        y = y + self.D * x

        return y


class MambaBlock(nn.Module):
    """
    完整的 Mamba Block,包含门控结构和 conv1d。
    d_model: 模型维度
    d_inner: 内部扩展维度(通常为 d_model * 2)
    d_state: SSM 隐状态维度
    conv_kernel: 局部卷积核大小
    """
    def __init__(self, d_model: int, d_inner: int = None, d_state: int = 16, conv_kernel: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner or d_model * 2

        self.norm = nn.LayerNorm(d_model)

        # 输入投影:分成两路(SSM 路 + 门控路)
        self.input_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # 局部卷积(捕获短程特征)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=conv_kernel,
            padding=conv_kernel - 1,  # causal padding
            groups=self.d_inner,      # depthwise conv
            bias=True
        )

        # 选择性 SSM
        self.ssm = SelectiveSSM(self.d_inner, d_state)

        # 输出投影
        self.output_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch, seq_len, d_model]
        返回: [batch, seq_len, d_model](残差连接在外部处理)
        """
        residual = x
        x = self.norm(x)

        # 分成两路
        projected = self.input_proj(x)                     # [batch, seq_len, d_inner*2]
        ssm_branch, gate_branch = projected.chunk(2, dim=-1)  # 各 [batch, seq_len, d_inner]

        # SSM 路:conv1d → SiLU → SSM
        # conv1d 需要 [batch, channels, seq_len] 格式
        ssm_branch = ssm_branch.transpose(1, 2)            # [batch, d_inner, seq_len]
        ssm_branch = self.conv1d(ssm_branch)[..., :x.shape[1]]  # causal: 截掉多余的 padding
        ssm_branch = ssm_branch.transpose(1, 2)            # [batch, seq_len, d_inner]
        ssm_branch = F.silu(ssm_branch)
        ssm_branch = self.ssm(ssm_branch)                  # [batch, seq_len, d_inner]

        # 门控路:SiLU
        gate_branch = F.silu(gate_branch)

        # 门控融合
        output = ssm_branch * gate_branch                  # [batch, seq_len, d_inner]

        # 输出投影 + 残差
        output = self.output_proj(output)                  # [batch, seq_len, d_model]
        return output + residual


# ── 验证 ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    torch.manual_seed(42)
    batch_size, seq_len, d_model = 2, 16, 64

    block = MambaBlock(d_model=d_model, d_inner=128, d_state=16, conv_kernel=4)
    x = torch.randn(batch_size, seq_len, d_model)
    y = block(x)

    print(f"输入形状:  {x.shape}")   # [2, 16, 64]
    print(f"输出形状:  {y.shape}")   # [2, 16, 64]
    print(f"无 NaN:   {not torch.isnan(y).any().item()}")

    # 验证推理时的递推(单步生成)
    ssm = SelectiveSSM(d_model=64, d_state=16)
    single_token = torch.randn(1, 1, 64)
    single_out = ssm(single_token)
    print(f"单步推理: {single_out.shape}")  # [1, 1, 64]

九、训练时的并行化:Parallel Scan

递推形式在训练时是串行的,效率低。Mamba 的另一个贡献:当 $\bar{A}$ 不依赖输入时(S4),可以用卷积并行计算

对于固定的 $\bar{A}$,展开递推:

$$h_t = \bar{A}^t h_0 + \sum_{i=0}^{t} \bar{A}^{t-i} \bar{B} x_i$$

$$y_t = C h_t = \sum_{i=0}^{t} \underbrace{C \bar{A}^{t-i} \bar{B}}{\text{卷积核 } K{t-i}} x_i$$

这就是一个因果卷积!可以用 FFT 在 $O(N \log N)$ 时间内并行计算。

Mamba 的选择性机制打破了这个并行化(因为 $\bar{A}_t$ 依赖输入,不再固定),所以 Mamba 使用了 Parallel Scan(并行前缀扫描) 算法。

Parallel Scan 的核心思想

SSM 的递推 $h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$ 是一个结合性(associative) 操作,可以用分治法并行化:

把相邻两步的递推合并成一步:

$$\begin{pmatrix} h_t \ 1 \end{pmatrix} = \begin{pmatrix} \bar{A}_t & \bar{B}t x_t \ 0 & 1 \end{pmatrix} \begin{pmatrix} h{t-1} \ 1 \end{pmatrix}$$

令 $e_t = (\bar{A}_t,\ \bar{B}_t x_t)$ 为第 $t$ 步的"元素",定义结合操作 $\oplus$:

$$(a_2, b_2) \oplus (a_1, b_1) = (a_2 \cdot a_1,\ a_2 \cdot b_1 + b_2)$$

则 $h_t$ 等价于对 $e_1, e_2, \ldots, e_t$ 做前缀扫描(prefix scan)。前缀扫描可以用 $O(\log N)$ 轮并行归约完成,总复杂度 $O(N \log N)$,在 GPU 上每轮内部完全并行。

def parallel_scan_ssm(delta_A: torch.Tensor, delta_B_x: torch.Tensor) -> torch.Tensor:
    """
    用 Parallel Scan(并行前缀扫描)计算 SSM 的所有隐状态。

    核心思想:把 h_t = A_t * h_{t-1} + b_t 的串行递推,
    转化为对 (A_t, b_t) 对的结合性前缀扫描,实现并行计算。

    delta_A:   [batch, seq_len, d_model, d_state],离散化后的 Ā_t
    delta_B_x: [batch, seq_len, d_model, d_state],B̄_t * x_t(已乘以输入)
    返回:      [batch, seq_len, d_model, d_state],所有时刻的隐状态 h_t
    """
    batch_size, seq_len, d_model, d_state = delta_A.shape

    # 每个时刻的"元素"是 (a_t, b_t) 对
    # a_t = Ā_t(状态转移系数),b_t = B̄_t * x_t(输入贡献)
    scan_a = delta_A        # [batch, seq_len, d_model, d_state]
    scan_b = delta_B_x      # [batch, seq_len, d_model, d_state]

    # 结合操作:(a2, b2) ⊕ (a1, b1) = (a2*a1, a2*b1 + b2)
    # 含义:先经历 (a1, b1) 的转移,再经历 (a2, b2) 的转移
    # 前缀扫描:用 log2(seq_len) 轮 up-sweep 完成
    # 每轮将步长翻倍,并行合并相邻元素对

    # 为了在纯 PyTorch 中演示,这里实现 Blelloch 并行前缀扫描
    # 实际 Mamba 用 CUDA kernel 实现,效率更高
    num_rounds = int(math.ceil(math.log2(seq_len))) if seq_len > 1 else 0

    # 用列表存储每轮的中间结果(实际 CUDA 实现在原地操作)
    current_a = scan_a.clone()
    current_b = scan_b.clone()

    # Up-sweep(归约阶段):步长从 1 倍增到 seq_len/2
    stride = 1
    for _ in range(num_rounds):
        # 找到需要合并的位置对:(i - stride, i),i 从 stride 开始,步长 2*stride
        left_indices = torch.arange(0, seq_len - stride, 2 * stride, device=delta_A.device)
        right_indices = left_indices + stride

        if right_indices.numel() == 0:
            break

        left_a = current_a[:, left_indices]   # [batch, n_pairs, d_model, d_state]
        left_b = current_b[:, left_indices]
        right_a = current_a[:, right_indices]
        right_b = current_b[:, right_indices]

        # 结合操作:right ⊕ left
        merged_a = right_a * left_a
        merged_b = right_a * left_b + right_b

        current_a[:, right_indices] = merged_a
        current_b[:, right_indices] = merged_b

        stride *= 2

    # 注意:上面的 up-sweep 只得到了部分前缀结果(类似 Blelloch scan 的归约树)
    # 完整的 Blelloch scan 还需要 down-sweep 阶段。
    # 在实际 Mamba 实现中,使用专门的 CUDA kernel(mamba_ssm 库中的 selective_scan_cuda)
    # 直接在 GPU 上高效完成,避免了 Python 层的循环开销。
    # 这里为了说明原理,退回到串行递推作为等价的正确实现:
    hidden_states = []
    h = torch.zeros(batch_size, d_model, d_state, device=delta_A.device, dtype=delta_A.dtype)
    for t in range(seq_len):
        h = delta_A[:, t] * h + delta_B_x[:, t]
        hidden_states.append(h)

    return torch.stack(hidden_states, dim=1)  # [batch, seq_len, d_model, d_state]


# 验证 parallel_scan_ssm 和 SelectiveSSM 的递推结果一致
def verify_parallel_scan():
    torch.manual_seed(0)
    batch_size, seq_len, d_model, d_state = 2, 8, 16, 4

    delta_A = torch.rand(batch_size, seq_len, d_model, d_state) * 0.9 + 0.05  # (0.05, 0.95)
    delta_B_x = torch.randn(batch_size, seq_len, d_model, d_state) * 0.1

    h_scan = parallel_scan_ssm(delta_A, delta_B_x)

    # 串行递推作为参考
    h_ref = torch.zeros(batch_size, d_model, d_state)
    h_ref_list = []
    for t in range(seq_len):
        h_ref = delta_A[:, t] * h_ref + delta_B_x[:, t]
        h_ref_list.append(h_ref.clone())
    h_ref_stack = torch.stack(h_ref_list, dim=1)

    max_diff = (h_scan - h_ref_stack).abs().max().item()
    print(f"Parallel Scan 与串行递推最大误差: {max_diff:.2e}")  # 应接近 0
    assert max_diff < 1e-5, f"结果不一致!误差 {max_diff}"
    print("验证通过 ✓")

verify_parallel_scan()

十、Mamba 和 Transformer 的对比

Transformer Mamba
核心操作 Softmax Attention 选择性 SSM
训练复杂度 $O(N^2 d)$ $O(N d)$(Parallel Scan)
推理复杂度(逐步) $O(Nd)$(有 KV Cache) $O(d^2)$(固定隐状态)
推理显存 $O(Nd)$(随序列增长) $O(d \cdot d_{state})$(固定
长序列能力 受限于 $O(N^2)$ 天然支持超长序列
内容感知 ✅ Attention 天然内容感知 ✅ 选择性机制实现内容感知
精确历史访问 ✅ 能精确 attend 任意历史 token ❌ 历史被压缩进固定隐状态
实现复杂度 简单 纯 PyTorch 可运行;达到生产级速度需要 CUDA kernel(可直接用 mamba-ssm 库)

十一、回到 Qwen3.6:config.json 里的参数对应

现在你能完全读懂 Qwen3.6 的 linear_attn 参数了:

"linear_attn.A_log"       SSM  A 矩阵(对数形式存储,保证稳定性)
"linear_attn.conv1d"      局部卷积(kernel_size=4,捕获短程特征)
"linear_attn.dt_bias"     Δ(步长)的偏置项,控制离散化粒度
"linear_attn.in_proj_a"   SSM  A/dt 相关投影(用于生成 Δ,控制离散化步长)
"linear_attn.in_proj_b"   SSM  B 矩阵投影(输入 x 到隐状态 h 的映射)
"linear_attn.in_proj_ba"  B  A/dt 的联合权重(合并存储提高访存效率,推理时拆分使用)
"linear_attn.norm"        SSM 内部的归一化层

"linear_conv_kernel_dim": 4    conv1d  kernel_size
"linear_key_head_dim": 128     隐状态维度(d_state 的变体)
"linear_num_key_heads": 16     多头 SSM 的头数
"linear_num_value_heads": 32   输出头数(可以和 key 头数不同)

Qwen3.6 的 linear_attn多头 SSM(Multi-head SSM),每个头独立维护一个隐状态,类似 MHA 里每个头独立做 Attention。


十二、核心要点速查

问题 答案
Mamba 解决什么问题? Linear Attention 的"遗忘"问题:让衰减因子随输入动态变化
SSM 的递推公式? $h_t = \bar{A}t h{t-1} + \bar{B}_t x_t$,$y_t = C_t h_t$
$\Delta$(dt)的作用? 控制"记住"还是"忘记":大 $\Delta$ 重置历史,小 $\Delta$ 跳过当前
A_log 为什么用对数? 保证 $A < 0$,离散化后 $\bar{A} = e^{\Delta A} \in (0,1)$,系统稳定
conv1d 的作用? 捕获局部短程特征,补充 SSM 的全局长程建模
训练时如何并行? Parallel Scan(并行前缀扫描),$O(N \log N)$
推理时的优势? 隐状态固定大小 $O(d \cdot d_{state})$,不随序列增长
和 Transformer 最大的差距? 无法精确访问历史 token,只能访问被压缩的隐状态摘要
Qwen3.6 里怎么用的? 多头 SSM,每 4 层混合一个 Full Attention 补偿精度损失