一句话:Mask 是在计算 softmax 之前,把某些位置的注意力分数强制设为 $-\infty$,使 softmax 后这些位置的权重变为 0,相当于"屏蔽掉"不该看的位置。
一、为什么需要 Mask?
Attention 的计算是:
$$ \text{scores} = \frac{QK^T}{\sqrt{d_k}}, \quad \text{weights} = \text{softmax}(\text{scores}), \quad \text{output} = \text{weights} \cdot V $$
默认情况下,每个 token 的 Q 会和序列中所有 token 的 K 做点积,包括:
- 无意义的
[PAD]填充 token - 未来还没生成的 token(训练时)
这两种情况都需要用 Mask 屏蔽掉。
二、两种 Mask
2.1 Padding Mask(填充遮罩)
问题:一个 batch 里不同句子长度不同,需要用 [PAD] token 补齐到相同长度,但 [PAD] 是无意义的,不应该被 Attention 到。
batch 里两个句子(补齐到长度5):
句子1: ["今", "天", "好", "[PAD]", "[PAD]"] ← 实际长度3,补了2个PAD
句子2: ["天", "气", "真", "不", "错" ] ← 实际长度5,无需补
Padding Mask(1=有效位置,0=PAD位置):
句子1: [1, 1, 1, 0, 0]
句子2: [1, 1, 1, 1, 1]
屏蔽效果:
句子1的注意力分数矩阵(屏蔽前):
今 天 好 PAD PAD
今 → [0.8, 0.3, 0.5, 0.2, 0.1]
天 → [0.4, 0.9, 0.3, 0.3, 0.2]
好 → [0.5, 0.4, 0.7, 0.1, 0.1]
屏蔽后(PAD列设为 -∞):
今 天 好 PAD PAD
今 → [0.8, 0.3, 0.5, -∞, -∞ ]
天 → [0.4, 0.9, 0.3, -∞, -∞ ]
好 → [0.5, 0.4, 0.7, -∞, -∞ ]
softmax 后 PAD 列权重 = 0,完全不影响输出
适用场景:训练和推理(有 padding 时)都需要。
2.2 Causal Mask(因果遮罩,也叫 Look-ahead Mask)
问题:GPT 类自回归模型训练时,预测第 t 个 token 时不能看到未来的 token,否则就是"作弊"——模型直接抄答案,学不到任何东西。
序列: ["今", "天", "天", "气", "好"]
训练目标:
看到"今",预测"天"
看到"今天",预测"天"
看到"今天天",预测"气"
...
Causal Mask(下三角矩阵,1=可以看,0=不能看):
今 天 天 气 好
今 → [ 1, 0, 0, 0, 0] ← "今"只能看自己
天 → [ 1, 1, 0, 0, 0] ← "天"能看"今"和自己
天 → [ 1, 1, 1, 0, 0]
气 → [ 1, 1, 1, 1, 0]
好 → [ 1, 1, 1, 1, 1] ← "好"能看所有历史
适用场景:只在训练 GPT/Decoder 类模型时需要。推理时自回归逐步生成,天然只能看到历史,不需要此 mask。
三、训练 vs 推理时传什么 Mask?
| 场景 | 传入的 Mask |
|---|---|
| 训练(GPT/Decoder) | Causal Mask(下三角)AND Padding Mask |
| 训练(BERT/Encoder) | 只有 Padding Mask(双向 Attention,可以看全部) |
| 推理,单条,无 padding | None(不需要任何 mask) |
| 推理,batch,有 padding | 只有 Padding Mask |
| 推理,自回归逐步生成 | None(每步只输入 1 个新 token,天然无未来信息) |
四、完整代码示例
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def make_padding_mask(token_ids: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
"""
生成 Padding Mask。
输入:token_ids [B, T]
输出:mask [B, 1, 1, T],可广播到 [B, num_heads, T_q, T_k]
值:PAD 位置为 False(屏蔽),非 PAD 位置为 True(保留)
"""
return (token_ids != pad_token_id).unsqueeze(1).unsqueeze(2) # [B, 1, 1, T]
def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
"""
生成 Causal Mask(下三角矩阵)。
输出:mask [1, 1, T, T],可广播到 [B, num_heads, T, T]
值:下三角为 True(可以看),上三角为 False(屏蔽未来)
"""
return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
).unsqueeze(0).unsqueeze(0) # [1, 1, T, T]
def make_decoder_mask(token_ids: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
"""
GPT 训练时的完整 Mask = Causal Mask AND Padding Mask。
两者取交集:既不能看未来,也不能看 PAD。
"""
B, T = token_ids.shape
causal = make_causal_mask(T, token_ids.device) # [1, 1, T, T]
padding = make_padding_mask(token_ids, pad_token_id) # [B, 1, 1, T]
return causal & padding # [B, 1, T, T]
class MultiHeadAttentionWithMask(nn.Module):
"""带 Mask 支持的 MHA"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.d_model = d_model
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
B, T, _ = x.shape
return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
B, T, _ = x.shape
Q = self.split_heads(self.W_Q(x)) # [B, h, T, d_k]
K = self.split_heads(self.W_K(x))
V = self.split_heads(self.W_V(x))
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, h, T, T]
if mask is not None:
# mask=False 的位置设为 -inf,softmax 后权重变为 0
scores = scores.masked_fill(mask == False, float('-inf'))
weights = F.softmax(scores, dim=-1)
attended = torch.matmul(weights, V)
attended = attended.transpose(1, 2).contiguous().view(B, T, self.d_model)
return self.W_O(attended)
# ── 使用示例 ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
d_model = 32
num_heads = 4
pad_token_id = 0
model = MultiHeadAttentionWithMask(d_model=d_model, num_heads=num_heads)
model.eval()
# ── 场景1:训练时(GPT/Decoder),有 PAD,需要 Causal + Padding Mask ──
token_ids_train = torch.tensor([
[3, 7, 2, 0, 0], # 句子1,后两个是 PAD
[5, 1, 8, 4, 6], # 句子2,无 PAD
])
x_train = torch.randn(2, 5, d_model)
decoder_mask = make_decoder_mask(token_ids_train, pad_token_id)
print("=== 训练场景(GPT Decoder)===")
print(f"token_ids shape: {token_ids_train.shape}")
print(f"decoder_mask shape: {decoder_mask.shape}") # [2, 1, 5, 5]
print("句子1的 mask(下三角 + 屏蔽PAD列):")
print(decoder_mask[0, 0].int())
# [[1, 0, 0, 0, 0],
# [1, 1, 0, 0, 0],
# [1, 1, 1, 0, 0], ← PAD列(第4、5列)被屏蔽
# [1, 1, 1, 0, 0], ← PAD行也只能看到有效列
# [1, 1, 1, 0, 0]]
with torch.no_grad():
output_train = model(x_train, mask=decoder_mask)
print(f"输出 shape: {output_train.shape}\n") # [2, 5, 32]
# ── 场景2:训练时(BERT/Encoder),有 PAD,只需 Padding Mask ──
token_ids_bert = torch.tensor([
[3, 7, 2, 0, 0],
[5, 1, 8, 4, 6],
])
padding_mask = make_padding_mask(token_ids_bert, pad_token_id) # [2, 1, 1, 5]
print("=== 训练场景(BERT Encoder)===")
print(f"padding_mask shape: {padding_mask.shape}")
print("句子1的 padding_mask:", padding_mask[0, 0, 0].int()) # [1, 1, 1, 0, 0]
with torch.no_grad():
output_bert = model(x_train, mask=padding_mask)
print(f"输出 shape: {output_bert.shape}\n")
# ── 场景3:推理时,单条输入,无 PAD,mask=None ──
single_input = torch.randn(1, 8, d_model) # batch=1,序列长度=8
print("=== 推理场景(单条,无PAD)===")
with torch.no_grad():
output_infer = model(single_input, mask=None) # 直接传 None
print(f"输出 shape: {output_infer.shape}") # [1, 8, 32]
五、Mask 在 Attention 计算中的位置
Q, K, V 计算完毕
↓
scores = Q @ K^T / √d_k # [B, h, T, T]
↓
if mask is not None:
scores[mask == False] = -∞ # 屏蔽不该看的位置
↓
weights = softmax(scores) # -∞ 位置的权重变为 0
↓
output = weights @ V # 被屏蔽的位置对输出无贡献
六、常见问题
Q:Causal Mask 为什么是下三角? 位置 i 的 token 只能看到位置 ≤ i 的 token(包括自己),矩阵的第 i 行第 j 列表示"位置 i 能否看到位置 j",因此 j ≤ i 的位置为 1,即下三角。
Q:推理时自回归生成为什么不需要 Causal Mask? 推理时每步只输入 1 个新 token(配合 KV Cache),Q 只有 1 行,K/V 是历史所有 token,天然就只能看到历史,不存在"看到未来"的问题。
Q:mask 的值为什么用 True/False 而不是 1/0?
两种写法都可以,只要在 masked_fill 时保持一致:
scores.masked_fill(mask == False, float('-inf')) # True=保留,False=屏蔽
scores.masked_fill(mask == 0, float('-inf')) # 1=保留,0=屏蔽(等价)
Q:BERT 为什么不需要 Causal Mask? BERT 是 Encoder-only 模型,做的是理解任务(分类、NER 等),一次性看完整个句子,双向 Attention 是它的优势,不需要屏蔽未来。
七、一句话总结
- Padding Mask:屏蔽
[PAD]填充位置,训练和推理(有 padding 时)都需要- Causal Mask:屏蔽未来位置(下三角矩阵),只在训练 GPT/Decoder 类模型时需要
- 推理时单条自回归生成:传
None,不需要任何 mask