一句话:标准 Attention 的复杂度是 $O(N^2)$,Linear Attention 通过改变计算顺序,把复杂度降到 $O(N)$,代价是用核函数近似替代 softmax,牺牲了一定的表达能力,但换来了推理时的递推形式(类似 RNN),天然支持无限长序列。
一、问题背景:标准 Attention 的瓶颈
标准 Attention 的计算:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
序列长度为 $N$ 时:
- 计算 $QK^T$ 需要 $O(N^2)$ 时间和空间
- 中间的 $N \times N$ 注意力矩阵需要 $O(N^2)$ 显存
后果:序列长度翻倍,计算量翻 4 倍,显存翻 4 倍。长序列(如 100K token)几乎不可能用标准 Attention。
二、核心思想:改变计算顺序
标准 Attention 的计算顺序
$$\text{out} = \underbrace{\text{softmax}(QK^T)}_{\text{先算这个,}N \times N \text{ 矩阵}} V$$
必须先把完整的 $N \times N$ 矩阵算出来,才能乘以 $V$。
Linear Attention 的关键洞察
如果去掉 softmax,矩阵乘法满足结合律:
$$\text{out} = (QK^T) V = Q \underbrace{(K^T V)}_{\text{先算这个,}d \times d \text{ 矩阵}}$$
先算 $K^T V$(形状 $d \times d$,和序列长度无关),再乘以 $Q$:
标准顺序:Q(N×d) · Kᵀ(d×N) = N×N 矩阵,再乘 V(N×d) → O(N²d)
线性顺序:Kᵀ(d×N) · V(N×d) = d×d 矩阵,Q(N×d) 再乘 → O(Nd²)
当 $d \ll N$ 时(通常 $d=64$,$N=4096$),$O(Nd^2)$ 远小于 $O(N^2 d)$。
为什么不能直接去掉 softmax?
softmax 保证了注意力权重非负且归一化(加起来等于 1),直接去掉会导致数值不稳定、训练发散。
解决方案:用核函数(Kernel Function) $\phi$ 替代 softmax:
$$\text{softmax}(q \cdot k) \approx \phi(q) \cdot \phi(k)$$
常见选择:
- $\phi(x) = \text{ELU}(x) + 1$(Linear Transformer,2020)
- $\phi(x) = e^x$(近似,但数值不稳定)
- $\phi(x) = \text{ReLU}(x)$(简单但效果一般)
三、Linear Attention 的完整公式
$$\text{out}i = \frac{\sum{j=1}^{N} \phi(q_i)^T \phi(k_j) v_j}{\sum_{j=1}^{N} \phi(q_i)^T \phi(k_j)}$$
利用结合律改写(分子):
$$\text{out}i = \frac{\phi(q_i)^T \left(\sum{j=1}^{N} \phi(k_j) v_j^T\right)}{\phi(q_i)^T \left(\sum_{j=1}^{N} \phi(k_j)\right)}$$
令 $S = \sum_{j=1}^{N} \phi(k_j) v_j^T$(形状 $d \times d$),$z = \sum_{j=1}^{N} \phi(k_j)$(形状 $d$):
$$\text{out}_i = \frac{\phi(q_i)^T S}{\phi(q_i)^T z}$$
$S$ 和 $z$ 只需要计算一次,所有 token 共享,复杂度降为 $O(N)$。
四、递推形式:Linear Attention ≈ RNN
Causal(因果)Linear Attention 的最大优势:可以写成递推形式。
对于自回归生成,第 $t$ 步的输出:
$$S_t = S_{t-1} + \phi(k_t) v_t^T$$ $$z_t = z_{t-1} + \phi(k_t)$$ $$\text{out}_t = \frac{\phi(q_t)^T S_t}{\phi(q_t)^T z_t}$$
import torch
import torch.nn.functional as F
def phi(x):
"""核函数:ELU + 1,保证输出非负"""
return F.elu(x) + 1
def linear_attention_recurrent(Q, K, V):
"""
Linear Attention 的递推形式(因果,自回归)。
Q, K, V: [seq_len, d]
返回: [seq_len, d]
"""
seq_len, d = Q.shape
outputs = []
# 递推状态:S 是 d×d 的累积矩阵,z 是 d 维的累积向量
S = torch.zeros(d, d)
z = torch.zeros(d)
for t in range(seq_len):
q_t = phi(Q[t]) # [d]
k_t = phi(K[t]) # [d]
v_t = V[t] # [d]
# 更新状态(累积历史 KV 信息)
S = S + torch.outer(k_t, v_t) # [d, d]
z = z + k_t # [d]
# 计算当前输出
numerator = q_t @ S # [d]
denominator = q_t @ z + 1e-6 # 标量,加 eps 防止除零
outputs.append(numerator / denominator)
return torch.stack(outputs, dim=0) # [seq_len, d]
# 验证
torch.manual_seed(42)
seq_len, d = 8, 16
Q = torch.randn(seq_len, d)
K = torch.randn(seq_len, d)
V = torch.randn(seq_len, d)
out = linear_attention_recurrent(Q, K, V)
print(out.shape) # torch.Size([8, 16])
print("no nan:", not torch.isnan(out).any().item())
这和 RNN 的结构完全一样:
RNN: h_t = f(h_{t-1}, x_t),用隐状态 h 压缩历史
Linear Attention:S_t = S_{t-1} + φ(k_t)vₜᵀ,用矩阵 S 压缩历史
五、并行训练形式
训练时不需要递推,可以并行计算(类似标准 Attention):
def linear_attention_parallel(Q, K, V):
"""
Linear Attention 的并行形式(非因果,训练用)。
Q, K, V: [seq_len, d]
返回: [seq_len, d]
"""
Q_phi = phi(Q) # [seq_len, d]
K_phi = phi(K) # [seq_len, d]
# 先算 KᵀV(d×d),再乘 Q,复杂度 O(Nd²)
KV = K_phi.T @ V # [d, d]
numerator = Q_phi @ KV # [seq_len, d]
# 归一化分母
z = K_phi.sum(dim=0) # [d]
denominator = (Q_phi @ z).unsqueeze(-1) + 1e-6 # [seq_len, 1]
return numerator / denominator # [seq_len, d]
# 验证
out_parallel = linear_attention_parallel(Q, K, V)
print(out_parallel.shape) # torch.Size([8, 16])
print("no nan:", not torch.isnan(out_parallel).any().item())
六、Linear Attention 的问题
问题1:表达能力弱于 softmax Attention
softmax 能产生非常"尖锐"的注意力分布(几乎只关注一个 token),而核函数近似做不到这一点,模型的选择性注意力能力下降。
问题2:隐状态是固定大小的矩阵
S 的形状是 d×d,无论序列多长,S 的大小不变
→ 长序列的历史信息被压缩进固定大小的矩阵
→ 远距离依赖容易被"遗忘"(和 RNN 的梯度消失类似)
这是 Linear Attention 和标准 Attention 最本质的差距:标准 Attention 能精确访问所有历史 token,Linear Attention 只能访问被压缩的历史摘要。
七、改进方向:现代线性注意力模型
RetNet(2023,微软)
引入衰减因子(Decay),让远距离的历史信息自然衰减,缓解"遗忘"问题:
$$S_t = \gamma \cdot S_{t-1} + \phi(k_t) v_t^T, \quad \gamma \in (0, 1)$$
衰减因子让模型更关注近期信息,类似 LSTM 的遗忘门。
Mamba(2023,SSM 架构)
不用核函数近似,而是用状态空间模型(SSM) 的框架,通过选择性机制(Selective SSM)让隐状态能动态决定"记住什么、忘记什么",效果接近 Transformer,推理速度接近 RNN。
GLA(Gated Linear Attention,2024)
在 Linear Attention 的递推公式里加入门控机制:
$$S_t = G_t \odot S_{t-1} + \phi(k_t) v_t^T$$
其中 $G_t$ 是由输入动态生成的门控矩阵,让模型自适应地控制历史信息的保留程度。
八、和标准 Attention / Flash Attention 的对比
| 标准 Attention | Flash Attention | Linear Attention | |
|---|---|---|---|
| 训练复杂度 | $O(N^2 d)$ | $O(N^2 d)$(显存优化) | $O(N d^2)$ |
| 推理复杂度(逐步生成) | $O(N d)$(有KV Cache) | $O(N d)$ | $O(d^2)$(固定!) |
| 推理显存 | $O(Nd)$(KV Cache随序列增长) | $O(Nd)$ | $O(d^2)$(固定!) |
| 表达能力 | 最强(精确 softmax) | 同标准 | 较弱(核函数近似) |
| 长序列支持 | 困难 | 较好 | 天然支持 |
| 是否需要 KV Cache | 需要 | 需要 | 不需要(状态固定大小) |
Linear Attention 推理时最大的优势:KV Cache 不随序列长度增长,固定是 $d \times d$ 的矩阵,推理 1K 和推理 1M token 占用的显存完全一样。
九、核心要点速查
| 问题 | 答案 |
|---|---|
| Linear Attention 解决什么问题? | 把标准 Attention 的 $O(N^2)$ 复杂度降到 $O(N)$ |
| 核心技巧是什么? | 改变矩阵乘法顺序,先算 $K^TV$($d \times d$)再乘 $Q$ |
| 为什么不直接去掉 softmax? | softmax 保证归一化,直接去掉数值不稳定,需要用核函数替代 |
| 递推形式是什么? | $S_t = S_{t-1} + \phi(k_t)v_t^T$,和 RNN 结构一样 |
| 最大的缺点是什么? | 历史信息被压缩进固定大小的矩阵 $S$,远距离依赖容易丢失 |
| 推理时的优势是什么? | 不需要 KV Cache,显存固定为 $O(d^2)$,与序列长度无关 |
| 代表性改进模型? | RetNet(衰减因子)、Mamba(SSM)、GLA(门控) |