一句话:Mask 是在计算 softmax 之前,把某些位置的注意力分数强制设为 $-\infty$,使 softmax 后这些位置的权重变为 0,相当于"屏蔽掉"不该看的位置。


一、为什么需要 Mask?

Attention 的计算是:

$$ \text{scores} = \frac{QK^T}{\sqrt{d_k}}, \quad \text{weights} = \text{softmax}(\text{scores}), \quad \text{output} = \text{weights} \cdot V $$

默认情况下,每个 token 的 Q 会和序列中所有 token 的 K 做点积,包括:

  • 无意义的 [PAD] 填充 token
  • 未来还没生成的 token(训练时)

这两种情况都需要用 Mask 屏蔽掉。


二、两种 Mask

2.1 Padding Mask(填充遮罩)

问题:一个 batch 里不同句子长度不同,需要用 [PAD] token 补齐到相同长度,但 [PAD] 是无意义的,不应该被 Attention 到。

batch 里两个句子(补齐到长度5):
  句子1: ["今", "天", "好", "[PAD]", "[PAD]"]   ← 实际长度3,补了2个PAD
  句子2: ["天", "气", "真", "不",    "错"   ]   ← 实际长度5,无需补

Padding Mask(1=有效位置,0=PAD位置):
  句子1: [1, 1, 1, 0, 0]
  句子2: [1, 1, 1, 1, 1]

屏蔽效果:

句子1的注意力分数矩阵(屏蔽前):
         今    天    好   PAD  PAD
  今  → [0.8, 0.3, 0.5, 0.2, 0.1]
  天  → [0.4, 0.9, 0.3, 0.3, 0.2]
  好  → [0.5, 0.4, 0.7, 0.1, 0.1]

屏蔽后(PAD列设为 -∞):
         今    天    好    PAD   PAD
  今  → [0.8, 0.3, 0.5,  -∞,   -∞ ]
  天  → [0.4, 0.9, 0.3,  -∞,   -∞ ]
  好  → [0.5, 0.4, 0.7,  -∞,   -∞ ]

softmax 后 PAD 列权重 = 0,完全不影响输出

适用场景:训练和推理(有 padding 时)都需要。


2.2 Causal Mask(因果遮罩,也叫 Look-ahead Mask)

问题:GPT 类自回归模型训练时,预测第 t 个 token 时不能看到未来的 token,否则就是"作弊"——模型直接抄答案,学不到任何东西。

序列: ["今", "天", "天", "气", "好"]
训练目标:
  看到"今",预测"天"
  看到"今天",预测"天"
  看到"今天天",预测"气"
  ...

Causal Mask(下三角矩阵,1=可以看,0=不能看):
         今  天  天  气  好
  今  → [ 1,  0,  0,  0,  0]   ← "今"只能看自己
  天  → [ 1,  1,  0,  0,  0]   ← "天"能看"今"和自己
  天  → [ 1,  1,  1,  0,  0]
  气  → [ 1,  1,  1,  1,  0]
  好  → [ 1,  1,  1,  1,  1]   ← "好"能看所有历史

适用场景:只在训练 GPT/Decoder 类模型时需要。推理时自回归逐步生成,天然只能看到历史,不需要此 mask。


三、训练 vs 推理时传什么 Mask?

场景 传入的 Mask
训练(GPT/Decoder) Causal Mask(下三角)AND Padding Mask
训练(BERT/Encoder) 只有 Padding Mask(双向 Attention,可以看全部)
推理,单条,无 padding None(不需要任何 mask)
推理,batch,有 padding 只有 Padding Mask
推理,自回归逐步生成 None(每步只输入 1 个新 token,天然无未来信息)

四、完整代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def make_padding_mask(token_ids: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
    """
    生成 Padding Mask。
    输入:token_ids [B, T]
    输出:mask [B, 1, 1, T],可广播到 [B, num_heads, T_q, T_k]
    值:PAD 位置为 False(屏蔽),非 PAD 位置为 True(保留)
    """
    return (token_ids != pad_token_id).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, T]


def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """
    生成 Causal Mask(下三角矩阵)。
    输出:mask [1, 1, T, T],可广播到 [B, num_heads, T, T]
    值:下三角为 True(可以看),上三角为 False(屏蔽未来)
    """
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
                      ).unsqueeze(0).unsqueeze(0)  # [1, 1, T, T]


def make_decoder_mask(token_ids: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
    """
    GPT 训练时的完整 Mask = Causal Mask AND Padding Mask。
    两者取交集:既不能看未来,也不能看 PAD。
    """
    B, T = token_ids.shape
    causal = make_causal_mask(T, token_ids.device)           # [1, 1, T, T]
    padding = make_padding_mask(token_ids, pad_token_id)     # [B, 1, 1, T]
    return causal & padding                                   # [B, 1, T, T]


class MultiHeadAttentionWithMask(nn.Module):
    """带 Mask 支持的 MHA"""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        B, T, _ = x.shape
        Q = self.split_heads(self.W_Q(x))  # [B, h, T, d_k]
        K = self.split_heads(self.W_K(x))
        V = self.split_heads(self.W_V(x))

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B, h, T, T]

        if mask is not None:
            # mask=False 的位置设为 -inf,softmax 后权重变为 0
            scores = scores.masked_fill(mask == False, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        attended = torch.matmul(weights, V)
        attended = attended.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.W_O(attended)


# ── 使用示例 ──────────────────────────────────────────────────────────────

if __name__ == "__main__":
    d_model = 32
    num_heads = 4
    pad_token_id = 0

    model = MultiHeadAttentionWithMask(d_model=d_model, num_heads=num_heads)
    model.eval()

    # ── 场景1:训练时(GPT/Decoder),有 PAD,需要 Causal + Padding Mask ──
    token_ids_train = torch.tensor([
        [3, 7, 2, 0, 0],   # 句子1,后两个是 PAD
        [5, 1, 8, 4, 6],   # 句子2,无 PAD
    ])
    x_train = torch.randn(2, 5, d_model)
    decoder_mask = make_decoder_mask(token_ids_train, pad_token_id)

    print("=== 训练场景(GPT Decoder)===")
    print(f"token_ids shape: {token_ids_train.shape}")
    print(f"decoder_mask shape: {decoder_mask.shape}")   # [2, 1, 5, 5]
    print("句子1的 mask(下三角 + 屏蔽PAD列):")
    print(decoder_mask[0, 0].int())
    # [[1, 0, 0, 0, 0],
    #  [1, 1, 0, 0, 0],
    #  [1, 1, 1, 0, 0],   ← PAD列(第4、5列)被屏蔽
    #  [1, 1, 1, 0, 0],   ← PAD行也只能看到有效列
    #  [1, 1, 1, 0, 0]]

    with torch.no_grad():
        output_train = model(x_train, mask=decoder_mask)
    print(f"输出 shape: {output_train.shape}\n")  # [2, 5, 32]

    # ── 场景2:训练时(BERT/Encoder),有 PAD,只需 Padding Mask ──
    token_ids_bert = torch.tensor([
        [3, 7, 2, 0, 0],
        [5, 1, 8, 4, 6],
    ])
    padding_mask = make_padding_mask(token_ids_bert, pad_token_id)  # [2, 1, 1, 5]

    print("=== 训练场景(BERT Encoder)===")
    print(f"padding_mask shape: {padding_mask.shape}")
    print("句子1的 padding_mask:", padding_mask[0, 0, 0].int())  # [1, 1, 1, 0, 0]

    with torch.no_grad():
        output_bert = model(x_train, mask=padding_mask)
    print(f"输出 shape: {output_bert.shape}\n")

    # ── 场景3:推理时,单条输入,无 PAD,mask=None ──
    single_input = torch.randn(1, 8, d_model)  # batch=1,序列长度=8

    print("=== 推理场景(单条,无PAD)===")
    with torch.no_grad():
        output_infer = model(single_input, mask=None)  # 直接传 None
    print(f"输出 shape: {output_infer.shape}")  # [1, 8, 32]

五、Mask 在 Attention 计算中的位置

Q, K, V 计算完毕
  ↓
scores = Q @ K^T / √d_k        # [B, h, T, T]
  ↓
if mask is not None:
    scores[mask == False] = -∞  # 屏蔽不该看的位置
  ↓
weights = softmax(scores)       # -∞ 位置的权重变为 0
  ↓
output = weights @ V            # 被屏蔽的位置对输出无贡献

六、常见问题

Q:Causal Mask 为什么是下三角? 位置 i 的 token 只能看到位置 ≤ i 的 token(包括自己),矩阵的第 i 行第 j 列表示"位置 i 能否看到位置 j",因此 j ≤ i 的位置为 1,即下三角。

Q:推理时自回归生成为什么不需要 Causal Mask? 推理时每步只输入 1 个新 token(配合 KV Cache),Q 只有 1 行,K/V 是历史所有 token,天然就只能看到历史,不存在"看到未来"的问题。

Q:mask 的值为什么用 True/False 而不是 1/0? 两种写法都可以,只要在 masked_fill 时保持一致:

scores.masked_fill(mask == False, float('-inf'))  # True=保留,False=屏蔽
scores.masked_fill(mask == 0,     float('-inf'))  # 1=保留,0=屏蔽(等价)

Q:BERT 为什么不需要 Causal Mask? BERT 是 Encoder-only 模型,做的是理解任务(分类、NER 等),一次性看完整个句子,双向 Attention 是它的优势,不需要屏蔽未来。


七、一句话总结

  • Padding Mask:屏蔽 [PAD] 填充位置,训练和推理(有 padding 时)都需要
  • Causal Mask:屏蔽未来位置(下三角矩阵),只在训练 GPT/Decoder 类模型时需要
  • 推理时单条自回归生成:传 None,不需要任何 mask