前置知识:已了解 Self-Attention(自注意力)。 MHA 的本质:把 Self-Attention 并行做多次,每次关注不同的语义子空间,最后合并结果。


一、为什么需要多头?Self-Attention 有什么不足?

单头 Self-Attention 每次只能学到一种"关注模式"。例如:

句子:"The animal didn't cross the street because it was too tired"
  • 头1 可能学到:itanimal(指代关系)
  • 头2 可能学到:tiredanimal(状态描述)
  • 头3 可能学到:crossstreet(动作与地点)

单头只能同时关注一种模式,多头让模型在不同子空间里并行捕捉多种关系。


二、MHA 整体结构

输入 X
  │
  ├──→ 线性投影 W_Q^1 → Q1 ─┐
  ├──→ 线性投影 W_K^1 → K1 ─┤→ Attention(Q1,K1,V1) → head_1 ─┐
  ├──→ 线性投影 W_V^1 → V1 ─┘                                   │
  │                                                              │
  ├──→ 线性投影 W_Q^2 → Q2 ─┐                                   │
  ├──→ 线性投影 W_K^2 → K2 ─┤→ Attention(Q2,K2,V2) → head_2 ─┤→ Concat → 线性投影 W_O → 输出
  ├──→ 线性投影 W_V^2 → V2 ─┘                                   │
  │                                                              │
  ├──→ ...(共 h 个头)                                          │
  │                                                              │
  └──→ 线性投影 W_Q^h → Qh ─┐                                   │
      线性投影 W_K^h → Kh ─┤→ Attention(Qh,Kh,Vh) → head_h ─┘
      线性投影 W_V^h → Vh ─┘

三、完整公式

单头 Scaled Dot-Product Attention(回顾)

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

多头注意力

$$ \text{head}_i = \text{Attention}(XW_Q^i,\ XW_K^i,\ XW_V^i) $$

$$ \text{MultiHead}(X) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \cdot W_O $$

参数说明:

符号 含义 维度
$X$ 输入序列 $[B, T, d_{model}]$
$W_Q^i, W_K^i, W_V^i$ 第 $i$ 个头的投影矩阵 $[d_{model}, d_k]$
$d_k = d_v = d_{model} / h$ 每个头的维度 标量
$W_O$ 输出投影矩阵 $[h \cdot d_v, d_{model}]$
$h$ 头的数量 标量(如 8、12、16)

关键设计:每个头的维度 $d_k = d_{model} / h$,所以多头的总计算量和单头相同,不增加额外开销。


四、逐步拆解计算过程(具体例子)

设定:

  • 输入序列长度 $T = 3$(3个词)
  • 模型维度 $d_{model} = 4$
  • 头数 $h = 2$
  • 每个头的维度 $d_k = d_{model} / h = 4 / 2 = 2$

输入 X(3个词,每词4维):

X = [[1, 0, 1, 0],   ← 词1
     [0, 1, 0, 1],   ← 词2
     [1, 1, 0, 0]]   ← 词3

第一步:线性投影,生成每个头的 Q、K、V

每个头有独立的投影矩阵 $W_Q^i, W_K^i, W_V^i$,维度为 $[4, 2]$。 以下给出头1和头2的全部投影矩阵,后续计算均基于这些固定数值。

头1的投影矩阵:

W_Q^1 = [[1, 0],    W_K^1 = [[1, 0],    W_V^1 = [[1, 0],
          [0, 1],              [0, 0],              [0, 1],
          [1, 0],              [0, 1],              [0, 0],
          [0, 1]]              [1, 0]]              [1, 0]]

头1的 Q1、K1、V1 计算(X @ W):

X = [[1, 0, 1, 0],
     [0, 1, 0, 1],
     [1, 1, 0, 0]]

Q1 = X @ W_Q^1:
  词1: [1*1+0*0+1*1+0*0, 1*0+0*1+1*0+0*1] = [2, 0]
  词2: [0*1+1*0+0*1+1*0, 0*0+1*1+0*0+1*1] = [0, 2]
  词3: [1*1+1*0+0*1+0*0, 1*0+1*1+0*0+0*1] = [1, 1]
  → Q1 = [[2, 0], [0, 2], [1, 1]]

K1 = X @ W_K^1:
  词1: [1*1+0*0+1*0+0*1, 1*0+0*0+1*1+0*0] = [1, 1]
  词2: [0*1+1*0+0*0+1*1, 0*0+1*0+0*1+1*0] = [1, 0]
  词3: [1*1+1*0+0*0+0*1, 1*0+1*0+0*1+0*0] = [1, 0]
  → K1 = [[1, 1], [1, 0], [1, 0]]

V1 = X @ W_V^1:
  词1: [1*1+0*0+1*0+0*1, 1*0+0*1+1*0+0*0] = [1, 0]
  词2: [0*1+1*0+0*0+1*1, 0*0+1*1+0*0+1*0] = [1, 1]
  词3: [1*1+1*0+0*0+0*1, 1*0+1*1+0*0+0*0] = [1, 1]
  → V1 = [[1, 0], [1, 1], [1, 1]]

头2的投影矩阵:

W_Q^2 = [[0, 1],    W_K^2 = [[0, 1],    W_V^2 = [[0, 1],
          [1, 0],              [1, 0],              [1, 0],
          [0, 1],              [0, 1],              [0, 0],
          [1, 0]]              [1, 0]]              [0, 1]]

头2的 Q2、K2、V2 计算:

Q2 = X @ W_Q^2:
  词1: [1*0+0*1+1*0+0*1, 1*1+0*0+1*1+0*0] = [0, 2]
  词2: [0*0+1*1+0*0+1*1, 0*1+1*0+0*1+1*0] = [2, 0]
  词3: [1*0+1*1+0*0+0*1, 1*1+1*0+0*1+0*0] = [1, 1]
  → Q2 = [[0, 2], [2, 0], [1, 1]]

K2 = X @ W_K^2:
  词1: [1*0+0*1+1*0+0*1, 1*1+0*0+1*1+0*0] = [0, 2]
  词2: [0*0+1*1+0*0+1*1, 0*1+1*0+0*1+1*0] = [2, 0]
  词3: [1*0+1*1+0*0+0*1, 1*1+1*0+0*1+0*0] = [1, 1]
  → K2 = [[0, 2], [2, 0], [1, 1]]

V2 = X @ W_V^2:
  词1: [1*0+0*1+1*0+0*0, 1*1+0*0+1*0+0*1] = [0, 1]
  词2: [0*0+1*1+0*0+1*0, 0*1+1*0+0*0+1*1] = [1, 1]
  词3: [1*0+1*1+0*0+0*0, 1*1+1*0+0*0+0*1] = [1, 1]
  → V2 = [[0, 1], [1, 1], [1, 1]]

第二步:每个头独立做 Scaled Dot-Product Attention

头1的计算:

scores1 = Q1 @ K1^T / √2

Q1 @ K1^T(先不除√2):
  K1^T = [[1, 1, 1],
           [1, 0, 0]]

  词1 [2,0] · K1^T: [2*1+0*1, 2*1+0*0, 2*1+0*0] = [2, 2, 2]
  词2 [0,2] · K1^T: [0*1+2*1, 0*1+2*0, 0*1+2*0] = [2, 0, 0]
  词3 [1,1] · K1^T: [1*1+1*1, 1*1+1*0, 1*1+1*0] = [2, 1, 1]

  Q1 @ K1^T = [[2, 2, 2],
               [2, 0, 0],
               [2, 1, 1]]

除以 √2 ≈ 1.414:
  scores1 = [[1.414, 1.414, 1.414],
             [1.414, 0.000, 0.000],
             [1.414, 0.707, 0.707]]

softmax(对每行独立计算):
  词1行 [1.414, 1.414, 1.414]:三个值相等 → softmax = [0.333, 0.333, 0.333]
  词2行 [1.414, 0.000, 0.000]:
    e^1.414=4.113, e^0=1, e^0=1, 总和=6.113
    → [4.113/6.113, 1/6.113, 1/6.113] = [0.673, 0.164, 0.164]
  词3行 [1.414, 0.707, 0.707]:
    e^1.414=4.113, e^0.707=2.028, e^0.707=2.028, 总和=8.169
    → [4.113/8.169, 2.028/8.169, 2.028/8.169] = [0.503, 0.248, 0.248]

  A1 = [[0.333, 0.333, 0.333],
        [0.673, 0.164, 0.164],
        [0.503, 0.248, 0.248]]

head_1 = A1 @ V1:
  V1 = [[1, 0], [1, 1], [1, 1]]

  词1: [0.333*1+0.333*1+0.333*1, 0.333*0+0.333*1+0.333*1] = [1.000, 0.667]
  词2: [0.673*1+0.164*1+0.164*1, 0.673*0+0.164*1+0.164*1] = [1.000, 0.327]
  词3: [0.503*1+0.248*1+0.248*1, 0.503*0+0.248*1+0.248*1] = [1.000, 0.497]

  head_1 = [[1.000, 0.667],
            [1.000, 0.327],
            [1.000, 0.497]]

头2的计算:

scores2 = Q2 @ K2^T / √2

Q2 @ K2^T(先不除√2):
  K2^T = [[0, 2, 1],
           [2, 0, 1]]

  词1 [0,2] · K2^T: [0*0+2*2, 0*2+2*0, 0*1+2*1] = [4, 0, 2]
  词2 [2,0] · K2^T: [2*0+0*2, 2*2+0*0, 2*1+0*1] = [0, 4, 2]
  词3 [1,1] · K2^T: [1*0+1*2, 1*2+1*0, 1*1+1*1] = [2, 2, 2]

  Q2 @ K2^T = [[4, 0, 2],
               [0, 4, 2],
               [2, 2, 2]]

除以 √2 ≈ 1.414:
  scores2 = [[2.828, 0.000, 1.414],
             [0.000, 2.828, 1.414],
             [1.414, 1.414, 1.414]]

softmax:
  词1行 [2.828, 0.000, 1.414]:
    e^2.828=16.93, e^0=1, e^1.414=4.113, 总和=22.04
    → [16.93/22.04, 1/22.04, 4.113/22.04] = [0.768, 0.045, 0.187]
  词2行 [0.000, 2.828, 1.414]:
    e^0=1, e^2.828=16.93, e^1.414=4.113, 总和=22.04
    → [1/22.04, 16.93/22.04, 4.113/22.04] = [0.045, 0.768, 0.187]
  词3行 [1.414, 1.414, 1.414]:三个值相等
    → [0.333, 0.333, 0.333]

  A2 = [[0.768, 0.045, 0.187],
        [0.045, 0.768, 0.187],
        [0.333, 0.333, 0.333]]

head_2 = A2 @ V2:
  V2 = [[0, 1], [1, 1], [1, 1]]

  词1: [0.768*0+0.045*1+0.187*1, 0.768*1+0.045*1+0.187*1] = [0.232, 1.000]
  词2: [0.045*0+0.768*1+0.187*1, 0.045*1+0.768*1+0.187*1] = [0.955, 1.000]
  词3: [0.333*0+0.333*1+0.333*1, 0.333*1+0.333*1+0.333*1] = [0.667, 1.000]

  head_2 = [[0.232, 1.000],
            [0.955, 1.000],
            [0.667, 1.000]]

观察:头1的注意力矩阵 A1 中词2强烈关注自己(0.673),头2的 A2 中词1强烈关注自己(0.768)、词2强烈关注自己(0.768)——两个头确实学到了不同的关注模式。


第三步:拼接所有头的输出

head_1 = [[1.000, 0.667],    head_2 = [[0.232, 1.000],
           [1.000, 0.327],              [0.955, 1.000],
           [1.000, 0.497]]              [0.667, 1.000]]

Concat(head_1, head_2) → 沿最后一维拼接,shape: [3, 4]

Concat = [[1.000, 0.667, 0.232, 1.000],   ← 词1
          [1.000, 0.327, 0.955, 1.000],   ← 词2
          [1.000, 0.497, 0.667, 1.000]]   ← 词3

第四步:输出线性投影

W_O 的维度:[4, 4](h×d_v → d_model,即 4→4)

取一个具体的 W_O(实际训练中由梯度下降学习得到,这里固定数值演示融合效果):
W_O = [[1,  0,  1,  0],
       [0,  1,  0,  1],
       [1,  1,  0,  0],
       [0,  0,  1,  1]]

Concat = [[1.000, 0.667, 0.232, 1.000],
          [1.000, 0.327, 0.955, 1.000],
          [1.000, 0.497, 0.667, 1.000]]

输出 = Concat @ W_O,逐行计算:

词1 [1.000, 0.667, 0.232, 1.000]:
  列0: 1.000*1 + 0.667*0 + 0.232*1 + 1.000*0 = 1.232
  列1: 1.000*0 + 0.667*1 + 0.232*1 + 1.000*0 = 0.899
  列2: 1.000*1 + 0.667*0 + 0.232*0 + 1.000*1 = 2.000
  列3: 1.000*0 + 0.667*1 + 0.232*0 + 1.000*1 = 1.667

词2 [1.000, 0.327, 0.955, 1.000]:
  列0: 1.000*1 + 0.327*0 + 0.955*1 + 1.000*0 = 1.955
  列1: 1.000*0 + 0.327*1 + 0.955*1 + 1.000*0 = 1.282
  列2: 1.000*1 + 0.327*0 + 0.955*0 + 1.000*1 = 2.000
  列3: 1.000*0 + 0.327*1 + 0.955*0 + 1.000*1 = 1.327

词3 [1.000, 0.497, 0.667, 1.000]:
  列0: 1.000*1 + 0.497*0 + 0.667*1 + 1.000*0 = 1.667
  列1: 1.000*0 + 0.497*1 + 0.667*1 + 1.000*0 = 1.164
  列2: 1.000*1 + 0.497*0 + 0.667*0 + 1.000*1 = 2.000
  列3: 1.000*0 + 0.497*1 + 0.667*0 + 1.000*1 = 1.497

输出 = [[1.232, 0.899, 2.000, 1.667],   ← 词1
        [1.955, 1.282, 2.000, 1.327],   ← 词2
        [1.667, 1.164, 2.000, 1.497]]   ← 词3

输出 shape 为 $[3, 4]$,与输入 X 的 shape $[3, 4]$ 完全一致。 W_O 的作用是将两个头的信息线性混合融合:每个输出维度都综合了来自 head_1 和 head_2 的不同特征,而不是简单地保留各头的独立输出。


五、维度变化总览

输入 X:          [B, T, d_model]      例:[2, 10, 512]
                        ↓ 每个头独立投影
Q_i, K_i, V_i:  [B, T, d_k]          例:[2, 10, 64]   (d_k = 512/8 = 64)
                        ↓ Scaled Dot-Product Attention
head_i:          [B, T, d_v]          例:[2, 10, 64]
                        ↓ Concat 所有头
Concat:          [B, T, h × d_v]      例:[2, 10, 512]  (8×64=512,恢复原维度)
                        ↓ 输出投影 W_O
输出:            [B, T, d_model]      例:[2, 10, 512]

六、实际模型中的头数配置

模型 $d_{model}$ 头数 $h$ 每头维度 $d_k$
Transformer(原论文) 512 8 64
BERT-Base 768 12 64
BERT-Large 1024 16 64
GPT-3 12288 96 128
LLaMA-7B 4096 32 128

规律:每头维度 $d_k$ 通常固定在 64 或 128,增大模型主要靠增加头数和 $d_{model}$。


七、MHA 的工程实现技巧

实际代码中,不会真的创建 h 个独立矩阵,而是用一个大矩阵一次性投影,再 reshape 成多头,效率更高。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        # 用一个大矩阵同时投影所有头的 Q、K、V(等价于 h 个独立矩阵)
        self.W_Q = nn.Linear(d_model, d_model)  # [d_model, d_model]
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)  # 输出投影

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """将 [B, T, d_model] reshape 成 [B, h, T, d_k]"""
        B, T, _ = x.shape
        x = x.view(B, T, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # [B, h, T, d_k]

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        B, T, _ = x.shape

        # 第一步:线性投影,生成 Q、K、V
        Q = self.split_heads(self.W_Q(x))  # [B, h, T, d_k]
        K = self.split_heads(self.W_K(x))  # [B, h, T, d_k]
        V = self.split_heads(self.W_V(x))  # [B, h, T, d_k]

        # 第二步:Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B, h, T, T]

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)  # [B, h, T, T]
        attended = torch.matmul(attention_weights, V)  # [B, h, T, d_k]

        # 第三步:合并多头
        attended = attended.transpose(1, 2).contiguous()  # [B, T, h, d_k]
        attended = attended.view(B, T, self.d_model)       # [B, T, d_model]

        # 第四步:输出投影
        output = self.W_O(attended)  # [B, T, d_model]
        return output


# 使用示例
model = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)   # batch=2, 序列长度=10, 维度=512
output = model(x)
print(output.shape)            # torch.Size([2, 10, 512])

八、MHA 的变体详解

变体演进的核心驱动力:推理时 KV Cache 占用显存太大、读写太慢。 每生成一个 token,都要把所有历史 token 的 K、V 从显存读出来做 Attention,这是推理速度的主要瓶颈。


8.1 MQA(Multi-Query Attention,2019)

核心改动:所有头共享同一组 K 和 V,只有 Q 保持多头。

MHA:Q1 K1 V1 | Q2 K2 V2 | Q3 K3 V3 | Q4 K4 V4   ← 每头独立 K/V
MQA:Q1        | Q2        | Q3        | Q4
          └──────────── K  V ────────────┘          ← 所有头共享一组 K/V

KV Cache 变化

  • MHA:每层缓存 2 × h × d_k 个值(h 个头各自的 K 和 V)
  • MQA:每层只缓存 2 × d_k 个值(只有 1 组 K/V)
  • 节省比例:h 倍(如 h=32,节省 32 倍 KV Cache)

代价:K/V 表达能力下降,多个头被迫用同一个 K/V 做 Attention,模型性能有所损失。

代表模型:PaLM、Falcon、早期 Gemini


8.2 GQA(Grouped-Query Attention,2023)

核心改动:把 h 个头分成 g 组,同一组内的头共享一组 K/V。MHA 是 g=h 的特例,MQA 是 g=1 的特例。

MHA(g=h=4):Q1 K1 V1 | Q2 K2 V2 | Q3 K3 V3 | Q4 K4 V4
GQA(g=2):  Q1 Q2      | Q3 Q4
               └─ K1 V1 ─┘  └─ K2 V2 ─┘          ← 每组共享一对 K/V
MQA(g=1):  Q1 Q2 Q3 Q4
               └────── K V ──────┘

KV Cache 变化

  • 每层缓存 2 × g × d_k 个值
  • g 越小,节省越多;g=h 退化为 MHA,g=1 退化为 MQA

优势:在 KV Cache 节省和模型性能之间取得平衡,实践中 g 取 h/4 或 h/8 效果接近 MHA。

代表模型:LLaMA-2/3、Mistral、Qwen2、Gemma2


8.3 MLA(Multi-head Latent Attention,2024,DeepSeek 提出)

核心改动:不直接缓存 K/V,而是缓存一个低秩压缩向量 c,推理时从 c 还原出 K/V。

原理

标准做法:  X → W_K → K(维度 d_k × h)  直接缓存 K
MLA 做法:  X → W_下投影 → c(维度 d_c,d_c << d_k × h)→ W_上投影 → K
                              ↑
                         只缓存这个低维向量 c

KV Cache 变化

  • 标准 MHA:每 token 每层缓存 2 × h × d_k
  • MLA:每 token 每层只缓存 d_c 维(DeepSeek-V2 中 d_c = 512,而 2×h×d_k = 2×128×128 = 32768)
  • 实际节省约 93.3% 的 KV Cache

额外优势:MLA 的表达能力经论文(TransMLA)证明强于 GQA,因为低秩压缩保留了跨头的信息交互,而 GQA 的共享 K/V 是对头的硬性绑定。

代价:推理时需要额外的上投影计算,但因为显存读写量大幅减少,整体推理速度仍然更快。

代表模型:DeepSeek-V2、DeepSeek-V3、DeepSeek-R1


8.4 Flash Attention(2022,工程优化,非结构变体)

重要区分:Flash Attention 不改变 Attention 的数学结果,它是一种 GPU 内存访问的工程优化,和 MHA/GQA/MLA 是正交的概念——可以叠加使用。

问题背景:标准 Attention 计算时,注意力矩阵 scores(shape: [B, h, T, T])需要完整写入 GPU HBM(高带宽内存),当序列长度 T=4096 时,这个矩阵有 4096×4096 = 1600 万个元素,读写极慢。

Flash Attention 的解法:分块(Tiling)计算,把 Q/K/V 分成小块,在 GPU SRAM(片上缓存,速度比 HBM 快 10 倍以上)中完成计算,永远不把完整的注意力矩阵写回 HBM

标准 Attention 内存访问路径:
  HBM → 读 Q,K → SRAM 计算 scores → HBM 写 scores
  HBM → 读 scores → SRAM softmax → HBM 写 softmax(scores)
  HBM → 读 softmax(scores), V → SRAM 计算输出 → HBM 写输出
  (3次大规模 HBM 读写)

Flash Attention 内存访问路径:
  分块循环:HBM → 读一小块 Q,K,V → SRAM 内完成所有计算 → HBM 写输出
  (只有 1 次 HBM 写,中间结果全在 SRAM)

效果

  • 速度提升 2~4 倍(Flash Attention 2),Flash Attention 3 针对 H100 进一步优化
  • 显存占用从 O(T²) 降为 O(T)(不存完整注意力矩阵)
  • 支持更长的序列长度

版本演进

  • Flash Attention 1(2022):提出分块计算思路
  • Flash Attention 2(2023):优化并行策略,减少非矩阵乘法运算,速度再提升约 2 倍
  • Flash Attention 3(2024):针对 Hopper 架构(H100)优化,利用异步流水线,速度达到 H100 理论峰值的 75%

使用现状:几乎所有现代开源大模型训练和推理都默认启用 Flash Attention 2,它和 GQA/MLA 可以同时使用。


8.5 各变体对比总览

变体 Q 头数 K/V 头数 KV Cache 大小 表达能力 代表模型
MHA h h 2×h×d_k ⭐⭐⭐⭐⭐ BERT、GPT-2、GPT-3
MQA h 1 2×d_k ⭐⭐⭐ PaLM、Falcon
GQA h g(1<g<h) 2×g×d_k ⭐⭐⭐⭐ LLaMA-2/3、Mistral、Qwen2
MLA h 低秩压缩 c d_c(远小于 2×h×d_k) ⭐⭐⭐⭐⭐ DeepSeek-V2/V3/R1
Flash Attention 不变(工程优化) 不变 几乎所有现代模型

2025 年主流开源模型的选择

  • Meta LLaMA-3:GQA + Flash Attention 2
  • Mistral / Mixtral:GQA + Flash Attention 2
  • Qwen2 / Qwen2.5:GQA + Flash Attention 2
  • DeepSeek-V3 / R1:MLA + Flash Attention 2(MLA 是目前最先进的注意力机制)
  • Google Gemma2:GQA + Flash Attention 2

结论:Flash Attention 2 几乎是所有先进开源模型的标配(工程层面),注意力结构层面 GQA 是主流,DeepSeek 的 MLA 是目前表达能力最强、KV Cache 最省的方案,代表最新方向。


九、多层 Transformer 堆叠(Multi-Layer)

多头 vs 多层,两个不同维度

多头(Multi-Head) 多层(Multi-Layer)
方向 同一层内并行 层与层之间串行
解决的问题 同时捕捉多种语义关系 逐层抽象,从低层特征到高层语义
数据流 输入 → 分成 h 份 → 各自 Attention → Concat → 输出 第1层输出 → 第2层输入 → … → 第N层输出

单个 Transformer Layer 的完整结构

每一层不只有 MHA,还包含 FFN 和残差连接:

输入 X
  │
  ├─→ LayerNorm → MHA → + ←── 残差(直接加上输入 X)
  │                     │
  │                     ↓
  └─────────────────→ LayerNorm → FFN → + ←── 残差
                                        │
                                        ↓
                                    该层输出(送入下一层)
  • 残差连接:防止梯度消失,让深层网络可训练(GPT-3 有 96 层,没有残差根本训不动)
  • FFN:两层线性变换 + 激活函数,维度先升后降(通常 d_model → 4×d_model → d_model),负责非线性特征变换
  • LayerNorm:稳定训练,现代模型多用 Pre-Norm(在 MHA/FFN 之前做归一化)

多层数据流示意

输入 tokens: ["今", "天", "天", "气", "好"]
  ↓ Embedding + Positional Encoding
X: [B, T, d_model]

━━━━━━━━━━━━━━ 第1层(有自己独立的 W_Q/W_K/W_V/W_O/FFN 参数)━━━━━━━━━━━━━━
  MHA(h 个头并行)→ Concat → W_O → 残差 → LayerNorm
  FFN → 残差 → LayerNorm
  输出 H1: [B, T, d_model]   ← shape 不变,内容变了(低层特征:词法、局部语法)

━━━━━━━━━━━━━━ 第2层(参数与第1层完全独立,不共享)━━━━━━━━━━━━━━
  输入是 H1,不是原始 X
  输出 H2: [B, T, d_model]   ← 中层特征:句法结构、指代关系

━━━━━━━━━━━━━━ 第N层 ━━━━━━━━━━━━━━
  输出 HN: [B, T, d_model]   ← 高层特征:语义、推理、抽象概念

关键:每层 shape 完全相同,但数值不同——这就是多层的本质,维度不变,语义逐层深化。

实际模型的层数配置

模型 层数 头数 d_model
BERT-Base 12 12 768
GPT-2(小) 12 12 768
LLaMA-3-8B 32 32 4096
LLaMA-3-70B 80 64 8192
GPT-3 96 96 12288

完整代码:多头 + 多层堆叠

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        B, T, _ = x.shape
        Q = self.split_heads(self.W_Q(x))
        K = self.split_heads(self.W_K(x))
        V = self.split_heads(self.W_V(x))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        attended = torch.matmul(weights, V)
        attended = attended.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.W_O(attended)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class TransformerLayer(nn.Module):
    """
    单个 Transformer 层 = MHA + FFN + 两个残差连接 + 两个 LayerNorm。
    每一层有自己独立的参数,层与层之间不共享。
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-Norm 风格:先 LayerNorm,再 MHA,再残差
        x = x + self.dropout(self.mha(self.norm1(x), mask))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x  # shape 不变:[B, T, d_model]


class TransformerEncoder(nn.Module):
    """
    多层 Transformer 堆叠。
    num_layers 个 TransformerLayer 串行连接,每层参数独立。
    """
    def __init__(
        self,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        vocab_size: int,
        max_seq_len: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        # 核心:num_layers 个独立的层,每层参数互不共享
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)

    def forward(
        self, token_ids: torch.Tensor, mask: torch.Tensor = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        B, T = token_ids.shape
        positions = torch.arange(T, device=token_ids.device).unsqueeze(0)
        x = self.dropout(self.token_embedding(token_ids) + self.pos_embedding(positions))

        layer_outputs = []
        for layer in self.layers:
            x = layer(x, mask)              # 上一层输出直接作为下一层输入
            layer_outputs.append(x.clone())  # 记录每层输出,便于观察

        return self.final_norm(x), layer_outputs


# 使用示例
if __name__ == "__main__":
    model = TransformerEncoder(
        num_layers=4, d_model=64, num_heads=4,
        d_ff=256, vocab_size=1000, max_seq_len=32,
    )
    total_params = sum(p.numel() for p in model.parameters())
    print(f"总参数量:{total_params:,},每层参数量:{total_params // 4:,}(4层各自独立)")

    token_ids = torch.randint(0, 1000, (2, 10))  # batch=2,序列长度=10
    final_output, layer_outputs = model(token_ids)

    for i, out in enumerate(layer_outputs):
        print(f"第{i+1}层输出 shape: {out.shape}  均值={out.mean():.4f}  std={out.std():.4f}")
    print(f"最终输出 shape: {final_output.shape}")
    # 每层 shape 相同 [2, 10, 64],但数值不同——逐层抽象的体现

十、一句话总结

MHA = 把输入投影到 h 个低维子空间,每个子空间独立做 Self-Attention,捕捉不同语义关系,最后拼接并投影回原维度。

多层堆叠 = 把 MHA + FFN + 残差 + LayerNorm 作为一个 Block,串行堆叠 N 次,每层有独立参数,逐层从低层特征抽象到高层语义。

核心公式记忆:

Q_i, K_i, V_i = X @ W_Q^i,  X @ W_K^i,  X @ W_V^i
head_i = softmax(Q_i K_i^T / √d_k) · V_i
MHA输出 = Concat(head_1...head_h) @ W_O

单层输出 = LayerNorm(X + FFN(LayerNorm(X + MHA(X))))
多层 = 第(n)层输出 → 第(n+1)层输入,串行 N 次