一句话:标准 Attention 的复杂度是 $O(N^2)$,Linear Attention 通过改变计算顺序,把复杂度降到 $O(N)$,代价是用核函数近似替代 softmax,牺牲了一定的表达能力,但换来了推理时的递推形式(类似 RNN),天然支持无限长序列。


一、问题背景:标准 Attention 的瓶颈

标准 Attention 的计算:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

序列长度为 $N$ 时:

  • 计算 $QK^T$ 需要 $O(N^2)$ 时间和空间
  • 中间的 $N \times N$ 注意力矩阵需要 $O(N^2)$ 显存

后果:序列长度翻倍,计算量翻 4 倍,显存翻 4 倍。长序列(如 100K token)几乎不可能用标准 Attention。


二、核心思想:改变计算顺序

标准 Attention 的计算顺序

$$\text{out} = \underbrace{\text{softmax}(QK^T)}_{\text{先算这个,}N \times N \text{ 矩阵}} V$$

必须先把完整的 $N \times N$ 矩阵算出来,才能乘以 $V$。

Linear Attention 的关键洞察

如果去掉 softmax,矩阵乘法满足结合律:

$$\text{out} = (QK^T) V = Q \underbrace{(K^T V)}_{\text{先算这个,}d \times d \text{ 矩阵}}$$

先算 $K^T V$(形状 $d \times d$,和序列长度无关),再乘以 $Q$

标准顺序:Q(N×d) · Kᵀ(d×N) = N×N 矩阵,再乘 V(N×d) → O(N²d)
线性顺序:Kᵀ(d×N) · V(N×d) = d×d 矩阵,Q(N×d) 再乘 → O(Nd²)

当 $d \ll N$ 时(通常 $d=64$,$N=4096$),$O(Nd^2)$ 远小于 $O(N^2 d)$。

为什么不能直接去掉 softmax?

softmax 保证了注意力权重非负且归一化(加起来等于 1),直接去掉会导致数值不稳定、训练发散。

解决方案:用核函数(Kernel Function) $\phi$ 替代 softmax:

$$\text{softmax}(q \cdot k) \approx \phi(q) \cdot \phi(k)$$

常见选择:

  • $\phi(x) = \text{ELU}(x) + 1$(Linear Transformer,2020)
  • $\phi(x) = e^x$(近似,但数值不稳定)
  • $\phi(x) = \text{ReLU}(x)$(简单但效果一般)

三、Linear Attention 的完整公式

$$\text{out}i = \frac{\sum{j=1}^{N} \phi(q_i)^T \phi(k_j) v_j}{\sum_{j=1}^{N} \phi(q_i)^T \phi(k_j)}$$

利用结合律改写(分子):

$$\text{out}i = \frac{\phi(q_i)^T \left(\sum{j=1}^{N} \phi(k_j) v_j^T\right)}{\phi(q_i)^T \left(\sum_{j=1}^{N} \phi(k_j)\right)}$$

令 $S = \sum_{j=1}^{N} \phi(k_j) v_j^T$(形状 $d \times d$),$z = \sum_{j=1}^{N} \phi(k_j)$(形状 $d$):

$$\text{out}_i = \frac{\phi(q_i)^T S}{\phi(q_i)^T z}$$

$S$ 和 $z$ 只需要计算一次,所有 token 共享,复杂度降为 $O(N)$。


四、递推形式:Linear Attention ≈ RNN

Causal(因果)Linear Attention 的最大优势:可以写成递推形式。

对于自回归生成,第 $t$ 步的输出:

$$S_t = S_{t-1} + \phi(k_t) v_t^T$$ $$z_t = z_{t-1} + \phi(k_t)$$ $$\text{out}_t = \frac{\phi(q_t)^T S_t}{\phi(q_t)^T z_t}$$

import torch
import torch.nn.functional as F


def phi(x):
    """核函数:ELU + 1,保证输出非负"""
    return F.elu(x) + 1


def linear_attention_recurrent(Q, K, V):
    """
    Linear Attention 的递推形式(因果,自回归)。
    Q, K, V: [seq_len, d]
    返回: [seq_len, d]
    """
    seq_len, d = Q.shape
    outputs = []

    # 递推状态:S 是 d×d 的累积矩阵,z 是 d 维的累积向量
    S = torch.zeros(d, d)
    z = torch.zeros(d)

    for t in range(seq_len):
        q_t = phi(Q[t])   # [d]
        k_t = phi(K[t])   # [d]
        v_t = V[t]        # [d]

        # 更新状态(累积历史 KV 信息)
        S = S + torch.outer(k_t, v_t)   # [d, d]
        z = z + k_t                      # [d]

        # 计算当前输出
        numerator = q_t @ S              # [d]
        denominator = q_t @ z + 1e-6    # 标量,加 eps 防止除零
        outputs.append(numerator / denominator)

    return torch.stack(outputs, dim=0)  # [seq_len, d]


# 验证
torch.manual_seed(42)
seq_len, d = 8, 16
Q = torch.randn(seq_len, d)
K = torch.randn(seq_len, d)
V = torch.randn(seq_len, d)

out = linear_attention_recurrent(Q, K, V)
print(out.shape)   # torch.Size([8, 16])
print("no nan:", not torch.isnan(out).any().item())

这和 RNN 的结构完全一样

RNN:    h_t = f(h_{t-1}, x_t),用隐状态 h 压缩历史
Linear Attention:S_t = S_{t-1} + φ(k_t)vₜᵀ,用矩阵 S 压缩历史

五、并行训练形式

训练时不需要递推,可以并行计算(类似标准 Attention):

def linear_attention_parallel(Q, K, V):
    """
    Linear Attention 的并行形式(非因果,训练用)。
    Q, K, V: [seq_len, d]
    返回: [seq_len, d]
    """
    Q_phi = phi(Q)   # [seq_len, d]
    K_phi = phi(K)   # [seq_len, d]

    # 先算 KᵀV(d×d),再乘 Q,复杂度 O(Nd²)
    KV = K_phi.T @ V                          # [d, d]
    numerator = Q_phi @ KV                    # [seq_len, d]

    # 归一化分母
    z = K_phi.sum(dim=0)                      # [d]
    denominator = (Q_phi @ z).unsqueeze(-1) + 1e-6  # [seq_len, 1]

    return numerator / denominator            # [seq_len, d]


# 验证
out_parallel = linear_attention_parallel(Q, K, V)
print(out_parallel.shape)   # torch.Size([8, 16])
print("no nan:", not torch.isnan(out_parallel).any().item())

六、Linear Attention 的问题

问题1:表达能力弱于 softmax Attention

softmax 能产生非常"尖锐"的注意力分布(几乎只关注一个 token),而核函数近似做不到这一点,模型的选择性注意力能力下降。

问题2:隐状态是固定大小的矩阵

S 的形状是 d×d,无论序列多长,S 的大小不变
→ 长序列的历史信息被压缩进固定大小的矩阵
→ 远距离依赖容易被"遗忘"(和 RNN 的梯度消失类似)

这是 Linear Attention 和标准 Attention 最本质的差距:标准 Attention 能精确访问所有历史 token,Linear Attention 只能访问被压缩的历史摘要


七、改进方向:现代线性注意力模型

RetNet(2023,微软)

引入衰减因子(Decay),让远距离的历史信息自然衰减,缓解"遗忘"问题:

$$S_t = \gamma \cdot S_{t-1} + \phi(k_t) v_t^T, \quad \gamma \in (0, 1)$$

衰减因子让模型更关注近期信息,类似 LSTM 的遗忘门。

Mamba(2023,SSM 架构)

不用核函数近似,而是用状态空间模型(SSM) 的框架,通过选择性机制(Selective SSM)让隐状态能动态决定"记住什么、忘记什么",效果接近 Transformer,推理速度接近 RNN。

GLA(Gated Linear Attention,2024)

在 Linear Attention 的递推公式里加入门控机制:

$$S_t = G_t \odot S_{t-1} + \phi(k_t) v_t^T$$

其中 $G_t$ 是由输入动态生成的门控矩阵,让模型自适应地控制历史信息的保留程度。


八、和标准 Attention / Flash Attention 的对比

标准 Attention Flash Attention Linear Attention
训练复杂度 $O(N^2 d)$ $O(N^2 d)$(显存优化) $O(N d^2)$
推理复杂度(逐步生成) $O(N d)$(有KV Cache) $O(N d)$ $O(d^2)$(固定!)
推理显存 $O(Nd)$(KV Cache随序列增长) $O(Nd)$ $O(d^2)$(固定!)
表达能力 最强(精确 softmax) 同标准 较弱(核函数近似)
长序列支持 困难 较好 天然支持
是否需要 KV Cache 需要 需要 不需要(状态固定大小)

Linear Attention 推理时最大的优势:KV Cache 不随序列长度增长,固定是 $d \times d$ 的矩阵,推理 1K 和推理 1M token 占用的显存完全一样。


九、核心要点速查

问题 答案
Linear Attention 解决什么问题? 把标准 Attention 的 $O(N^2)$ 复杂度降到 $O(N)$
核心技巧是什么? 改变矩阵乘法顺序,先算 $K^TV$($d \times d$)再乘 $Q$
为什么不直接去掉 softmax? softmax 保证归一化,直接去掉数值不稳定,需要用核函数替代
递推形式是什么? $S_t = S_{t-1} + \phi(k_t)v_t^T$,和 RNN 结构一样
最大的缺点是什么? 历史信息被压缩进固定大小的矩阵 $S$,远距离依赖容易丢失
推理时的优势是什么? 不需要 KV Cache,显存固定为 $O(d^2)$,与序列长度无关
代表性改进模型? RetNet(衰减因子)、Mamba(SSM)、GLA(门控)