前置知识:已了解 Self-Attention(自注意力)。 MHA 的本质:把 Self-Attention 并行做多次,每次关注不同的语义子空间,最后合并结果。
一、为什么需要多头?Self-Attention 有什么不足?
单头 Self-Attention 每次只能学到一种"关注模式"。例如:
句子:"The animal didn't cross the street because it was too tired"
- 头1 可能学到:
it→animal(指代关系) - 头2 可能学到:
tired→animal(状态描述) - 头3 可能学到:
cross→street(动作与地点)
单头只能同时关注一种模式,多头让模型在不同子空间里并行捕捉多种关系。
二、MHA 整体结构
输入 X
│
├──→ 线性投影 W_Q^1 → Q1 ─┐
├──→ 线性投影 W_K^1 → K1 ─┤→ Attention(Q1,K1,V1) → head_1 ─┐
├──→ 线性投影 W_V^1 → V1 ─┘ │
│ │
├──→ 线性投影 W_Q^2 → Q2 ─┐ │
├──→ 线性投影 W_K^2 → K2 ─┤→ Attention(Q2,K2,V2) → head_2 ─┤→ Concat → 线性投影 W_O → 输出
├──→ 线性投影 W_V^2 → V2 ─┘ │
│ │
├──→ ...(共 h 个头) │
│ │
└──→ 线性投影 W_Q^h → Qh ─┐ │
线性投影 W_K^h → Kh ─┤→ Attention(Qh,Kh,Vh) → head_h ─┘
线性投影 W_V^h → Vh ─┘
三、完整公式
单头 Scaled Dot-Product Attention(回顾)
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
多头注意力
$$ \text{head}_i = \text{Attention}(XW_Q^i,\ XW_K^i,\ XW_V^i) $$
$$ \text{MultiHead}(X) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \cdot W_O $$
参数说明:
| 符号 | 含义 | 维度 |
|---|---|---|
| $X$ | 输入序列 | $[B, T, d_{model}]$ |
| $W_Q^i, W_K^i, W_V^i$ | 第 $i$ 个头的投影矩阵 | $[d_{model}, d_k]$ |
| $d_k = d_v = d_{model} / h$ | 每个头的维度 | 标量 |
| $W_O$ | 输出投影矩阵 | $[h \cdot d_v, d_{model}]$ |
| $h$ | 头的数量 | 标量(如 8、12、16) |
关键设计:每个头的维度 $d_k = d_{model} / h$,所以多头的总计算量和单头相同,不增加额外开销。
四、逐步拆解计算过程(具体例子)
设定:
- 输入序列长度 $T = 3$(3个词)
- 模型维度 $d_{model} = 4$
- 头数 $h = 2$
- 每个头的维度 $d_k = d_{model} / h = 4 / 2 = 2$
输入 X(3个词,每词4维):
X = [[1, 0, 1, 0], ← 词1
[0, 1, 0, 1], ← 词2
[1, 1, 0, 0]] ← 词3
第一步:线性投影,生成每个头的 Q、K、V
每个头有独立的投影矩阵 $W_Q^i, W_K^i, W_V^i$,维度为 $[4, 2]$。 以下给出头1和头2的全部投影矩阵,后续计算均基于这些固定数值。
头1的投影矩阵:
W_Q^1 = [[1, 0], W_K^1 = [[1, 0], W_V^1 = [[1, 0],
[0, 1], [0, 0], [0, 1],
[1, 0], [0, 1], [0, 0],
[0, 1]] [1, 0]] [1, 0]]
头1的 Q1、K1、V1 计算(X @ W):
X = [[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 1, 0, 0]]
Q1 = X @ W_Q^1:
词1: [1*1+0*0+1*1+0*0, 1*0+0*1+1*0+0*1] = [2, 0]
词2: [0*1+1*0+0*1+1*0, 0*0+1*1+0*0+1*1] = [0, 2]
词3: [1*1+1*0+0*1+0*0, 1*0+1*1+0*0+0*1] = [1, 1]
→ Q1 = [[2, 0], [0, 2], [1, 1]]
K1 = X @ W_K^1:
词1: [1*1+0*0+1*0+0*1, 1*0+0*0+1*1+0*0] = [1, 1]
词2: [0*1+1*0+0*0+1*1, 0*0+1*0+0*1+1*0] = [1, 0]
词3: [1*1+1*0+0*0+0*1, 1*0+1*0+0*1+0*0] = [1, 0]
→ K1 = [[1, 1], [1, 0], [1, 0]]
V1 = X @ W_V^1:
词1: [1*1+0*0+1*0+0*1, 1*0+0*1+1*0+0*0] = [1, 0]
词2: [0*1+1*0+0*0+1*1, 0*0+1*1+0*0+1*0] = [1, 1]
词3: [1*1+1*0+0*0+0*1, 1*0+1*1+0*0+0*0] = [1, 1]
→ V1 = [[1, 0], [1, 1], [1, 1]]
头2的投影矩阵:
W_Q^2 = [[0, 1], W_K^2 = [[0, 1], W_V^2 = [[0, 1],
[1, 0], [1, 0], [1, 0],
[0, 1], [0, 1], [0, 0],
[1, 0]] [1, 0]] [0, 1]]
头2的 Q2、K2、V2 计算:
Q2 = X @ W_Q^2:
词1: [1*0+0*1+1*0+0*1, 1*1+0*0+1*1+0*0] = [0, 2]
词2: [0*0+1*1+0*0+1*1, 0*1+1*0+0*1+1*0] = [2, 0]
词3: [1*0+1*1+0*0+0*1, 1*1+1*0+0*1+0*0] = [1, 1]
→ Q2 = [[0, 2], [2, 0], [1, 1]]
K2 = X @ W_K^2:
词1: [1*0+0*1+1*0+0*1, 1*1+0*0+1*1+0*0] = [0, 2]
词2: [0*0+1*1+0*0+1*1, 0*1+1*0+0*1+1*0] = [2, 0]
词3: [1*0+1*1+0*0+0*1, 1*1+1*0+0*1+0*0] = [1, 1]
→ K2 = [[0, 2], [2, 0], [1, 1]]
V2 = X @ W_V^2:
词1: [1*0+0*1+1*0+0*0, 1*1+0*0+1*0+0*1] = [0, 1]
词2: [0*0+1*1+0*0+1*0, 0*1+1*0+0*0+1*1] = [1, 1]
词3: [1*0+1*1+0*0+0*0, 1*1+1*0+0*0+0*1] = [1, 1]
→ V2 = [[0, 1], [1, 1], [1, 1]]
第二步:每个头独立做 Scaled Dot-Product Attention
头1的计算:
scores1 = Q1 @ K1^T / √2
Q1 @ K1^T(先不除√2):
K1^T = [[1, 1, 1],
[1, 0, 0]]
词1 [2,0] · K1^T: [2*1+0*1, 2*1+0*0, 2*1+0*0] = [2, 2, 2]
词2 [0,2] · K1^T: [0*1+2*1, 0*1+2*0, 0*1+2*0] = [2, 0, 0]
词3 [1,1] · K1^T: [1*1+1*1, 1*1+1*0, 1*1+1*0] = [2, 1, 1]
Q1 @ K1^T = [[2, 2, 2],
[2, 0, 0],
[2, 1, 1]]
除以 √2 ≈ 1.414:
scores1 = [[1.414, 1.414, 1.414],
[1.414, 0.000, 0.000],
[1.414, 0.707, 0.707]]
softmax(对每行独立计算):
词1行 [1.414, 1.414, 1.414]:三个值相等 → softmax = [0.333, 0.333, 0.333]
词2行 [1.414, 0.000, 0.000]:
e^1.414=4.113, e^0=1, e^0=1, 总和=6.113
→ [4.113/6.113, 1/6.113, 1/6.113] = [0.673, 0.164, 0.164]
词3行 [1.414, 0.707, 0.707]:
e^1.414=4.113, e^0.707=2.028, e^0.707=2.028, 总和=8.169
→ [4.113/8.169, 2.028/8.169, 2.028/8.169] = [0.503, 0.248, 0.248]
A1 = [[0.333, 0.333, 0.333],
[0.673, 0.164, 0.164],
[0.503, 0.248, 0.248]]
head_1 = A1 @ V1:
V1 = [[1, 0], [1, 1], [1, 1]]
词1: [0.333*1+0.333*1+0.333*1, 0.333*0+0.333*1+0.333*1] = [1.000, 0.667]
词2: [0.673*1+0.164*1+0.164*1, 0.673*0+0.164*1+0.164*1] = [1.000, 0.327]
词3: [0.503*1+0.248*1+0.248*1, 0.503*0+0.248*1+0.248*1] = [1.000, 0.497]
head_1 = [[1.000, 0.667],
[1.000, 0.327],
[1.000, 0.497]]
头2的计算:
scores2 = Q2 @ K2^T / √2
Q2 @ K2^T(先不除√2):
K2^T = [[0, 2, 1],
[2, 0, 1]]
词1 [0,2] · K2^T: [0*0+2*2, 0*2+2*0, 0*1+2*1] = [4, 0, 2]
词2 [2,0] · K2^T: [2*0+0*2, 2*2+0*0, 2*1+0*1] = [0, 4, 2]
词3 [1,1] · K2^T: [1*0+1*2, 1*2+1*0, 1*1+1*1] = [2, 2, 2]
Q2 @ K2^T = [[4, 0, 2],
[0, 4, 2],
[2, 2, 2]]
除以 √2 ≈ 1.414:
scores2 = [[2.828, 0.000, 1.414],
[0.000, 2.828, 1.414],
[1.414, 1.414, 1.414]]
softmax:
词1行 [2.828, 0.000, 1.414]:
e^2.828=16.93, e^0=1, e^1.414=4.113, 总和=22.04
→ [16.93/22.04, 1/22.04, 4.113/22.04] = [0.768, 0.045, 0.187]
词2行 [0.000, 2.828, 1.414]:
e^0=1, e^2.828=16.93, e^1.414=4.113, 总和=22.04
→ [1/22.04, 16.93/22.04, 4.113/22.04] = [0.045, 0.768, 0.187]
词3行 [1.414, 1.414, 1.414]:三个值相等
→ [0.333, 0.333, 0.333]
A2 = [[0.768, 0.045, 0.187],
[0.045, 0.768, 0.187],
[0.333, 0.333, 0.333]]
head_2 = A2 @ V2:
V2 = [[0, 1], [1, 1], [1, 1]]
词1: [0.768*0+0.045*1+0.187*1, 0.768*1+0.045*1+0.187*1] = [0.232, 1.000]
词2: [0.045*0+0.768*1+0.187*1, 0.045*1+0.768*1+0.187*1] = [0.955, 1.000]
词3: [0.333*0+0.333*1+0.333*1, 0.333*1+0.333*1+0.333*1] = [0.667, 1.000]
head_2 = [[0.232, 1.000],
[0.955, 1.000],
[0.667, 1.000]]
观察:头1的注意力矩阵 A1 中词2强烈关注自己(0.673),头2的 A2 中词1强烈关注自己(0.768)、词2强烈关注自己(0.768)——两个头确实学到了不同的关注模式。
第三步:拼接所有头的输出
head_1 = [[1.000, 0.667], head_2 = [[0.232, 1.000],
[1.000, 0.327], [0.955, 1.000],
[1.000, 0.497]] [0.667, 1.000]]
Concat(head_1, head_2) → 沿最后一维拼接,shape: [3, 4]
Concat = [[1.000, 0.667, 0.232, 1.000], ← 词1
[1.000, 0.327, 0.955, 1.000], ← 词2
[1.000, 0.497, 0.667, 1.000]] ← 词3
第四步:输出线性投影
W_O 的维度:[4, 4](h×d_v → d_model,即 4→4)
取一个具体的 W_O(实际训练中由梯度下降学习得到,这里固定数值演示融合效果):
W_O = [[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 1, 0, 0],
[0, 0, 1, 1]]
Concat = [[1.000, 0.667, 0.232, 1.000],
[1.000, 0.327, 0.955, 1.000],
[1.000, 0.497, 0.667, 1.000]]
输出 = Concat @ W_O,逐行计算:
词1 [1.000, 0.667, 0.232, 1.000]:
列0: 1.000*1 + 0.667*0 + 0.232*1 + 1.000*0 = 1.232
列1: 1.000*0 + 0.667*1 + 0.232*1 + 1.000*0 = 0.899
列2: 1.000*1 + 0.667*0 + 0.232*0 + 1.000*1 = 2.000
列3: 1.000*0 + 0.667*1 + 0.232*0 + 1.000*1 = 1.667
词2 [1.000, 0.327, 0.955, 1.000]:
列0: 1.000*1 + 0.327*0 + 0.955*1 + 1.000*0 = 1.955
列1: 1.000*0 + 0.327*1 + 0.955*1 + 1.000*0 = 1.282
列2: 1.000*1 + 0.327*0 + 0.955*0 + 1.000*1 = 2.000
列3: 1.000*0 + 0.327*1 + 0.955*0 + 1.000*1 = 1.327
词3 [1.000, 0.497, 0.667, 1.000]:
列0: 1.000*1 + 0.497*0 + 0.667*1 + 1.000*0 = 1.667
列1: 1.000*0 + 0.497*1 + 0.667*1 + 1.000*0 = 1.164
列2: 1.000*1 + 0.497*0 + 0.667*0 + 1.000*1 = 2.000
列3: 1.000*0 + 0.497*1 + 0.667*0 + 1.000*1 = 1.497
输出 = [[1.232, 0.899, 2.000, 1.667], ← 词1
[1.955, 1.282, 2.000, 1.327], ← 词2
[1.667, 1.164, 2.000, 1.497]] ← 词3
输出 shape 为 $[3, 4]$,与输入 X 的 shape $[3, 4]$ 完全一致。 W_O 的作用是将两个头的信息线性混合融合:每个输出维度都综合了来自 head_1 和 head_2 的不同特征,而不是简单地保留各头的独立输出。
五、维度变化总览
输入 X: [B, T, d_model] 例:[2, 10, 512]
↓ 每个头独立投影
Q_i, K_i, V_i: [B, T, d_k] 例:[2, 10, 64] (d_k = 512/8 = 64)
↓ Scaled Dot-Product Attention
head_i: [B, T, d_v] 例:[2, 10, 64]
↓ Concat 所有头
Concat: [B, T, h × d_v] 例:[2, 10, 512] (8×64=512,恢复原维度)
↓ 输出投影 W_O
输出: [B, T, d_model] 例:[2, 10, 512]
六、实际模型中的头数配置
| 模型 | $d_{model}$ | 头数 $h$ | 每头维度 $d_k$ |
|---|---|---|---|
| Transformer(原论文) | 512 | 8 | 64 |
| BERT-Base | 768 | 12 | 64 |
| BERT-Large | 1024 | 16 | 64 |
| GPT-3 | 12288 | 96 | 128 |
| LLaMA-7B | 4096 | 32 | 128 |
规律:每头维度 $d_k$ 通常固定在 64 或 128,增大模型主要靠增加头数和 $d_{model}$。
七、MHA 的工程实现技巧
实际代码中,不会真的创建 h 个独立矩阵,而是用一个大矩阵一次性投影,再 reshape 成多头,效率更高。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
# 用一个大矩阵同时投影所有头的 Q、K、V(等价于 h 个独立矩阵)
self.W_Q = nn.Linear(d_model, d_model) # [d_model, d_model]
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model) # 输出投影
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""将 [B, T, d_model] reshape 成 [B, h, T, d_k]"""
B, T, _ = x.shape
x = x.view(B, T, self.num_heads, self.d_k)
return x.transpose(1, 2) # [B, h, T, d_k]
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
B, T, _ = x.shape
# 第一步:线性投影,生成 Q、K、V
Q = self.split_heads(self.W_Q(x)) # [B, h, T, d_k]
K = self.split_heads(self.W_K(x)) # [B, h, T, d_k]
V = self.split_heads(self.W_V(x)) # [B, h, T, d_k]
# 第二步:Scaled Dot-Product Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, h, T, T]
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1) # [B, h, T, T]
attended = torch.matmul(attention_weights, V) # [B, h, T, d_k]
# 第三步:合并多头
attended = attended.transpose(1, 2).contiguous() # [B, T, h, d_k]
attended = attended.view(B, T, self.d_model) # [B, T, d_model]
# 第四步:输出投影
output = self.W_O(attended) # [B, T, d_model]
return output
# 使用示例
model = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512) # batch=2, 序列长度=10, 维度=512
output = model(x)
print(output.shape) # torch.Size([2, 10, 512])
八、MHA 的变体详解
变体演进的核心驱动力:推理时 KV Cache 占用显存太大、读写太慢。 每生成一个 token,都要把所有历史 token 的 K、V 从显存读出来做 Attention,这是推理速度的主要瓶颈。
8.1 MQA(Multi-Query Attention,2019)
核心改动:所有头共享同一组 K 和 V,只有 Q 保持多头。
MHA:Q1 K1 V1 | Q2 K2 V2 | Q3 K3 V3 | Q4 K4 V4 ← 每头独立 K/V
MQA:Q1 | Q2 | Q3 | Q4
└──────────── K V ────────────┘ ← 所有头共享一组 K/V
KV Cache 变化:
- MHA:每层缓存
2 × h × d_k个值(h 个头各自的 K 和 V) - MQA:每层只缓存
2 × d_k个值(只有 1 组 K/V) - 节省比例:h 倍(如 h=32,节省 32 倍 KV Cache)
代价:K/V 表达能力下降,多个头被迫用同一个 K/V 做 Attention,模型性能有所损失。
代表模型:PaLM、Falcon、早期 Gemini
8.2 GQA(Grouped-Query Attention,2023)
核心改动:把 h 个头分成 g 组,同一组内的头共享一组 K/V。MHA 是 g=h 的特例,MQA 是 g=1 的特例。
MHA(g=h=4):Q1 K1 V1 | Q2 K2 V2 | Q3 K3 V3 | Q4 K4 V4
GQA(g=2): Q1 Q2 | Q3 Q4
└─ K1 V1 ─┘ └─ K2 V2 ─┘ ← 每组共享一对 K/V
MQA(g=1): Q1 Q2 Q3 Q4
└────── K V ──────┘
KV Cache 变化:
- 每层缓存
2 × g × d_k个值 - g 越小,节省越多;g=h 退化为 MHA,g=1 退化为 MQA
优势:在 KV Cache 节省和模型性能之间取得平衡,实践中 g 取 h/4 或 h/8 效果接近 MHA。
代表模型:LLaMA-2/3、Mistral、Qwen2、Gemma2
8.3 MLA(Multi-head Latent Attention,2024,DeepSeek 提出)
核心改动:不直接缓存 K/V,而是缓存一个低秩压缩向量 c,推理时从 c 还原出 K/V。
原理:
标准做法: X → W_K → K(维度 d_k × h) 直接缓存 K
MLA 做法: X → W_下投影 → c(维度 d_c,d_c << d_k × h)→ W_上投影 → K
↑
只缓存这个低维向量 c
KV Cache 变化:
- 标准 MHA:每 token 每层缓存
2 × h × d_k维 - MLA:每 token 每层只缓存
d_c维(DeepSeek-V2 中 d_c = 512,而 2×h×d_k = 2×128×128 = 32768) - 实际节省约 93.3% 的 KV Cache
额外优势:MLA 的表达能力经论文(TransMLA)证明强于 GQA,因为低秩压缩保留了跨头的信息交互,而 GQA 的共享 K/V 是对头的硬性绑定。
代价:推理时需要额外的上投影计算,但因为显存读写量大幅减少,整体推理速度仍然更快。
代表模型:DeepSeek-V2、DeepSeek-V3、DeepSeek-R1
8.4 Flash Attention(2022,工程优化,非结构变体)
重要区分:Flash Attention 不改变 Attention 的数学结果,它是一种 GPU 内存访问的工程优化,和 MHA/GQA/MLA 是正交的概念——可以叠加使用。
问题背景:标准 Attention 计算时,注意力矩阵 scores(shape: [B, h, T, T])需要完整写入 GPU HBM(高带宽内存),当序列长度 T=4096 时,这个矩阵有 4096×4096 = 1600 万个元素,读写极慢。
Flash Attention 的解法:分块(Tiling)计算,把 Q/K/V 分成小块,在 GPU SRAM(片上缓存,速度比 HBM 快 10 倍以上)中完成计算,永远不把完整的注意力矩阵写回 HBM。
标准 Attention 内存访问路径:
HBM → 读 Q,K → SRAM 计算 scores → HBM 写 scores
HBM → 读 scores → SRAM softmax → HBM 写 softmax(scores)
HBM → 读 softmax(scores), V → SRAM 计算输出 → HBM 写输出
(3次大规模 HBM 读写)
Flash Attention 内存访问路径:
分块循环:HBM → 读一小块 Q,K,V → SRAM 内完成所有计算 → HBM 写输出
(只有 1 次 HBM 写,中间结果全在 SRAM)
效果:
- 速度提升 2~4 倍(Flash Attention 2),Flash Attention 3 针对 H100 进一步优化
- 显存占用从 O(T²) 降为 O(T)(不存完整注意力矩阵)
- 支持更长的序列长度
版本演进:
- Flash Attention 1(2022):提出分块计算思路
- Flash Attention 2(2023):优化并行策略,减少非矩阵乘法运算,速度再提升约 2 倍
- Flash Attention 3(2024):针对 Hopper 架构(H100)优化,利用异步流水线,速度达到 H100 理论峰值的 75%
使用现状:几乎所有现代开源大模型训练和推理都默认启用 Flash Attention 2,它和 GQA/MLA 可以同时使用。
8.5 各变体对比总览
| 变体 | Q 头数 | K/V 头数 | KV Cache 大小 | 表达能力 | 代表模型 |
|---|---|---|---|---|---|
| MHA | h | h | 2×h×d_k | ⭐⭐⭐⭐⭐ | BERT、GPT-2、GPT-3 |
| MQA | h | 1 | 2×d_k | ⭐⭐⭐ | PaLM、Falcon |
| GQA | h | g(1<g<h) | 2×g×d_k | ⭐⭐⭐⭐ | LLaMA-2/3、Mistral、Qwen2 |
| MLA | h | 低秩压缩 c | d_c(远小于 2×h×d_k) | ⭐⭐⭐⭐⭐ | DeepSeek-V2/V3/R1 |
| Flash Attention | — | — | 不变(工程优化) | 不变 | 几乎所有现代模型 |
2025 年主流开源模型的选择:
- Meta LLaMA-3:GQA + Flash Attention 2
- Mistral / Mixtral:GQA + Flash Attention 2
- Qwen2 / Qwen2.5:GQA + Flash Attention 2
- DeepSeek-V3 / R1:MLA + Flash Attention 2(MLA 是目前最先进的注意力机制)
- Google Gemma2:GQA + Flash Attention 2
结论:Flash Attention 2 几乎是所有先进开源模型的标配(工程层面),注意力结构层面 GQA 是主流,DeepSeek 的 MLA 是目前表达能力最强、KV Cache 最省的方案,代表最新方向。
九、多层 Transformer 堆叠(Multi-Layer)
多头 vs 多层,两个不同维度
| 多头(Multi-Head) | 多层(Multi-Layer) | |
|---|---|---|
| 方向 | 同一层内并行 | 层与层之间串行 |
| 解决的问题 | 同时捕捉多种语义关系 | 逐层抽象,从低层特征到高层语义 |
| 数据流 | 输入 → 分成 h 份 → 各自 Attention → Concat → 输出 | 第1层输出 → 第2层输入 → … → 第N层输出 |
单个 Transformer Layer 的完整结构
每一层不只有 MHA,还包含 FFN 和残差连接:
输入 X
│
├─→ LayerNorm → MHA → + ←── 残差(直接加上输入 X)
│ │
│ ↓
└─────────────────→ LayerNorm → FFN → + ←── 残差
│
↓
该层输出(送入下一层)
- 残差连接:防止梯度消失,让深层网络可训练(GPT-3 有 96 层,没有残差根本训不动)
- FFN:两层线性变换 + 激活函数,维度先升后降(通常 d_model → 4×d_model → d_model),负责非线性特征变换
- LayerNorm:稳定训练,现代模型多用 Pre-Norm(在 MHA/FFN 之前做归一化)
多层数据流示意
输入 tokens: ["今", "天", "天", "气", "好"]
↓ Embedding + Positional Encoding
X: [B, T, d_model]
━━━━━━━━━━━━━━ 第1层(有自己独立的 W_Q/W_K/W_V/W_O/FFN 参数)━━━━━━━━━━━━━━
MHA(h 个头并行)→ Concat → W_O → 残差 → LayerNorm
FFN → 残差 → LayerNorm
输出 H1: [B, T, d_model] ← shape 不变,内容变了(低层特征:词法、局部语法)
━━━━━━━━━━━━━━ 第2层(参数与第1层完全独立,不共享)━━━━━━━━━━━━━━
输入是 H1,不是原始 X
输出 H2: [B, T, d_model] ← 中层特征:句法结构、指代关系
━━━━━━━━━━━━━━ 第N层 ━━━━━━━━━━━━━━
输出 HN: [B, T, d_model] ← 高层特征:语义、推理、抽象概念
关键:每层 shape 完全相同,但数值不同——这就是多层的本质,维度不变,语义逐层深化。
实际模型的层数配置
| 模型 | 层数 | 头数 | d_model |
|---|---|---|---|
| BERT-Base | 12 | 12 | 768 |
| GPT-2(小) | 12 | 12 | 768 |
| LLaMA-3-8B | 32 | 32 | 4096 |
| LLaMA-3-70B | 80 | 64 | 8192 |
| GPT-3 | 96 | 96 | 12288 |
完整代码:多头 + 多层堆叠
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
B, T, _ = x.shape
return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
B, T, _ = x.shape
Q = self.split_heads(self.W_Q(x))
K = self.split_heads(self.W_K(x))
V = self.split_heads(self.W_V(x))
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = self.dropout(F.softmax(scores, dim=-1))
attended = torch.matmul(weights, V)
attended = attended.transpose(1, 2).contiguous().view(B, T, self.d_model)
return self.W_O(attended)
class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear2(self.dropout(F.relu(self.linear1(x))))
class TransformerLayer(nn.Module):
"""
单个 Transformer 层 = MHA + FFN + 两个残差连接 + 两个 LayerNorm。
每一层有自己独立的参数,层与层之间不共享。
"""
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.mha = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
# Pre-Norm 风格:先 LayerNorm,再 MHA,再残差
x = x + self.dropout(self.mha(self.norm1(x), mask))
x = x + self.dropout(self.ffn(self.norm2(x)))
return x # shape 不变:[B, T, d_model]
class TransformerEncoder(nn.Module):
"""
多层 Transformer 堆叠。
num_layers 个 TransformerLayer 串行连接,每层参数独立。
"""
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int,
vocab_size: int,
max_seq_len: int,
dropout: float = 0.1,
):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
self.dropout = nn.Dropout(dropout)
# 核心:num_layers 个独立的层,每层参数互不共享
self.layers = nn.ModuleList([
TransformerLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(d_model)
def forward(
self, token_ids: torch.Tensor, mask: torch.Tensor = None
) -> tuple[torch.Tensor, list[torch.Tensor]]:
B, T = token_ids.shape
positions = torch.arange(T, device=token_ids.device).unsqueeze(0)
x = self.dropout(self.token_embedding(token_ids) + self.pos_embedding(positions))
layer_outputs = []
for layer in self.layers:
x = layer(x, mask) # 上一层输出直接作为下一层输入
layer_outputs.append(x.clone()) # 记录每层输出,便于观察
return self.final_norm(x), layer_outputs
# 使用示例
if __name__ == "__main__":
model = TransformerEncoder(
num_layers=4, d_model=64, num_heads=4,
d_ff=256, vocab_size=1000, max_seq_len=32,
)
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量:{total_params:,},每层参数量:{total_params // 4:,}(4层各自独立)")
token_ids = torch.randint(0, 1000, (2, 10)) # batch=2,序列长度=10
final_output, layer_outputs = model(token_ids)
for i, out in enumerate(layer_outputs):
print(f"第{i+1}层输出 shape: {out.shape} 均值={out.mean():.4f} std={out.std():.4f}")
print(f"最终输出 shape: {final_output.shape}")
# 每层 shape 相同 [2, 10, 64],但数值不同——逐层抽象的体现
十、一句话总结
MHA = 把输入投影到 h 个低维子空间,每个子空间独立做 Self-Attention,捕捉不同语义关系,最后拼接并投影回原维度。
多层堆叠 = 把 MHA + FFN + 残差 + LayerNorm 作为一个 Block,串行堆叠 N 次,每层有独立参数,逐层从低层特征抽象到高层语义。
核心公式记忆:
Q_i, K_i, V_i = X @ W_Q^i, X @ W_K^i, X @ W_V^i
head_i = softmax(Q_i K_i^T / √d_k) · V_i
MHA输出 = Concat(head_1...head_h) @ W_O
单层输出 = LayerNorm(X + FFN(LayerNorm(X + MHA(X))))
多层 = 第(n)层输出 → 第(n+1)层输入,串行 N 次