一句话:KV Cache 是推理时把已经算过的 Key 和 Value 缓存起来,避免每生成一个新 token 都重复计算历史 token 的 K/V,是大模型推理加速的核心技术。


一、为什么需要 KV Cache?

大模型生成 token 的方式:自回归(Auto-regressive)

大模型生成文本是一个 token 一个 token 地生成的,每次生成下一个 token 时,都要把所有历史 token 作为输入重新过一遍 Transformer。

输入:  "今天天气"
生成:  "今天天气" → "很"
生成:  "今天天气很" → "好"
生成:  "今天天气很好" → "!"

每一步生成,Attention 都要计算:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

其中 Q、K、V 都来自当前所有 token(包括历史的)。

没有 KV Cache 时的重复计算

假设已经生成了 t 个 token,现在要生成第 t+1 个:

第1步生成"很":
  对 token ["今","天","天","气"] 全部计算 K、V
  → K = [K_今, K_天, K_天, K_气]
  → V = [V_今, V_天, V_天, V_气]

第2步生成"好":
  对 token ["今","天","天","气","很"] 全部计算 K、V
  → K = [K_今, K_天, K_天, K_气, K_很]   ← 前4个和上一步完全一样!
  → V = [V_今, V_天, V_天, V_气, V_很]   ← 前4个和上一步完全一样!

第3步生成"!":
  对 token ["今","天","天","气","很","好"] 全部计算 K、V
  → 前5个 K/V 和上一步完全一样!

历史 token 的 K/V 每次都重新算,纯属浪费。

生成长度为 T 的序列,总计算量是 O(T²),随序列变长急剧增加。


二、KV Cache 的原理

核心思想

历史 token 的 K 和 V 不会随新 token 的生成而改变(在 Decoder-only 模型中,历史 token 不会 attend 到未来 token,所以它们的 K/V 是固定的)。

因此,把每一层每个历史 token 的 K/V 存下来,下次直接读取,不再重新计算。

有 KV Cache 后的计算流程

第1步生成"很":
  计算所有 token 的 K/V:[K_今, K_天, K_天, K_气]
  → 存入 Cache:cache_K = [K_今, K_天, K_天, K_气]
                cache_V = [V_今, V_天, V_天, V_气]

第2步生成"好":
  只计算新 token "很" 的 K/V:K_很, V_很
  → 从 Cache 读出历史:[K_今, K_天, K_天, K_气]
  → 拼接:K = [K_今, K_天, K_天, K_气, K_很]  ← 直接用,不重算
  → 更新 Cache:cache_K = [K_今, K_天, K_天, K_气, K_很]

第3步生成"!":
  只计算新 token "好" 的 K/V:K_好, V_好
  → 从 Cache 读出历史,拼接,继续...

每步只需计算 1 个新 token 的 K/V,历史的全部从缓存读取。

计算量从 O(T²) 降为 O(T),速度提升显著。


三、KV Cache 的显存占用计算

KV Cache 的显存占用是推理时显存的主要来源之一,计算公式:

$$ \text{KV Cache 大小} = 2 \times L \times h \times d_k \times T \times \text{bytes_per_element} $$

参数 含义
2 K 和 V 各一份
L Transformer 层数
h 注意力头数
d_k 每个头的维度
T 序列长度(已生成的 token 数)
bytes_per_element 数据精度(fp16=2字节,fp32=4字节,int8=1字节)

具体例子:LLaMA-3-8B 的 KV Cache

LLaMA-3-8B 的参数:

  • 层数 L = 32
  • 头数 h = 32(Q头),KV头数 = 8(GQA,g=8)
  • 每头维度 d_k = 128
  • 精度 fp16(2字节)

序列长度 T = 4096 时的 KV Cache:

KV Cache = 2 × 32层 × 8个KV头 × 128维 × 4096 token × 2字节
         = 2 × 32 × 8 × 128 × 4096 × 2
         = 536,870,912 字节
         ≈ 512 MB

序列长度 T = 32768(32K 上下文)时:

KV Cache = 512 MB × (32768 / 4096) = 512 MB × 8 = 4 GB

这就是为什么长上下文推理显存消耗巨大,也是 MQA/GQA/MLA 要压缩 KV Cache 的根本原因。


四、KV Cache 在哪里用?

只在 Decoder(生成阶段)使用

场景 是否用 KV Cache 原因
GPT 类模型推理(生成) ✅ 用 自回归生成,历史 K/V 固定
BERT 类模型推理(理解) ❌ 不用 一次性处理全部输入,不逐步生成
Encoder-Decoder(如 T5) Decoder 部分用 Encoder 输出固定,Decoder 自回归
训练阶段 ❌ 不用 训练时并行处理所有 token,不需要逐步生成

Prefill 阶段 vs Decode 阶段

大模型推理分两个阶段:

Prefill(预填充)阶段:
  输入:完整的 prompt(如 "请帮我写一首诗:\n")
  操作:一次性并行计算所有 prompt token 的 K/V,存入 Cache
  特点:计算密集,GPU 利用率高

Decode(解码)阶段:
  输入:每次只有 1 个新生成的 token
  操作:计算新 token 的 K/V,从 Cache 读取历史 K/V,拼接后做 Attention
  特点:内存带宽密集(主要时间花在读 Cache),GPU 利用率低

五、具体计算过程举例

设定

  • 模型:单层 Transformer,2个注意力头,每头维度 d_k = 2
  • 已有 prompt:[“A”, “B”],现在要生成第3个 token

Prefill 阶段(处理 prompt “A B”)

固定投影矩阵(整个例子全程使用这三个矩阵):
  W_K = [[1, 0],    W_V = [[1, 1],    W_Q = [[1, 0],
          [0, 1],           [0, 1],           [1, 1],
          [1, 0],           [1, 0],           [0, 1],
          [0, 1]]           [0, 0]]           [1, 0]]
  (均为 [4, 2] 的矩阵,将 4 维 embedding 投影到 2 维)

token embedding(固定数值):
  x_A = [1, 0, 1, 0]
  x_B = [0, 1, 0, 1]
  x_C = [1, 1, 0, 0]   ← Decode 第1步的新 token
  x_D = [0, 0, 1, 1]   ← Decode 第2步的新 token

计算 K/V(K = x @ W_K,V = x @ W_V):

  K_A = x_A @ W_K = [1*1+0*0+1*1+0*0, 1*0+0*1+1*0+0*1] = [2, 0]
  K_B = x_B @ W_K = [0*1+1*0+0*1+1*0, 0*0+1*1+0*0+1*1] = [0, 2]

  V_A = x_A @ W_V = [1*1+0*0+1*1+0*0, 1*1+0*1+1*0+0*0] = [2, 1]
  V_B = x_B @ W_V = [0*1+1*0+0*1+1*0, 0*1+1*1+0*0+1*0] = [0, 1]

存入 KV Cache:
  cache_K = [[2, 0],   ← K_A
             [0, 2]]   ← K_B
  cache_V = [[2, 1],   ← V_A
             [0, 1]]   ← V_B

Decode 阶段第1步(生成 token C)

x_C = [1, 1, 0, 0](已在 Prefill 阶段给出)

只计算 C 的 K/V(使用与 Prefill 相同的 W_K、W_V):
  K_C = x_C @ W_K = [1*1+1*0+0*1+0*0, 1*0+1*1+0*0+0*1] = [1, 1]
  V_C = x_C @ W_V = [1*1+1*0+0*1+0*0, 1*1+1*1+0*0+0*0] = [1, 2]

从 Cache 读取历史,拼接:
  K_all = [[2, 0],   ← K_A(从 Cache 读,不重新计算)
           [0, 2],   ← K_B(从 Cache 读,不重新计算)
           [1, 1]]   ← K_C(刚计算)

  V_all = [[2, 1],   ← V_A(从 Cache 读,不重新计算)
           [0, 1],   ← V_B(从 Cache 读,不重新计算)
           [1, 2]]   ← V_C(刚计算)

Q_C = x_C @ W_Q = [1*1+1*1+0*0+0*1, 1*0+1*1+0*1+0*0] = [2, 1]

Attention 分数(Q_C 对所有 K):
  score_A = Q_C · K_A / √2 = (2*2 + 1*0) / 1.414 = 4.0 / 1.414 = 2.828
  score_B = Q_C · K_B / √2 = (2*0 + 1*2) / 1.414 = 2.0 / 1.414 = 1.414
  score_C = Q_C · K_C / √2 = (2*1 + 1*1) / 1.414 = 3.0 / 1.414 = 2.121

softmax([2.828, 1.414, 2.121]):
  e^2.828=16.930, e^1.414=4.113, e^2.121=8.337, 总和=29.380
  权重 = [16.930/29.380, 4.113/29.380, 8.337/29.380]
       = [0.576, 0.140, 0.284]

输出 = 0.576*V_A + 0.140*V_B + 0.284*V_C
     = 0.576*[2,1] + 0.140*[0,1] + 0.284*[1,2]
     = [1.152, 0.576] + [0.000, 0.140] + [0.284, 0.568]
     = [1.436, 1.284]

更新 Cache:
  cache_K 追加 K_C → [[2,0], [0,2], [1,1]]
  cache_V 追加 V_C → [[2,1], [0,1], [1,2]]

Decode 阶段第2步(生成 token D)

x_D = [0, 0, 1, 1](已在 Prefill 阶段给出)

只计算 D 的 K/V(使用与 Prefill 相同的 W_K、W_V):
  K_D = x_D @ W_K = [0*1+0*0+1*1+1*0, 0*0+0*1+1*0+1*1] = [1, 1]
  V_D = x_D @ W_V = [0*1+0*0+1*1+1*0, 0*1+0*1+1*0+1*0] = [1, 0]

从 Cache 读取历史(此时 Cache 已包含 A、B、C):
  cache_K = [[2, 0],   ← K_A(不重新计算)
             [0, 2],   ← K_B(不重新计算)
             [1, 1]]   ← K_C(不重新计算)

拼接 K_D:
  K_all = [[2, 0],   ← K_A(Cache 读取)
           [0, 2],   ← K_B(Cache 读取)
           [1, 1],   ← K_C(Cache 读取)
           [1, 1]]   ← K_D(刚计算)

  V_all = [[2, 1],   ← V_A(Cache 读取)
           [0, 1],   ← V_B(Cache 读取)
           [1, 2],   ← V_C(Cache 读取)
           [1, 0]]   ← V_D(刚计算)

Q_D = x_D @ W_Q = [0*1+0*1+1*0+1*1, 0*0+0*1+1*1+1*0] = [1, 1]

Attention 分数(Q_D 对所有 K):
  score_A = Q_D · K_A / √2 = (1*2 + 1*0) / 1.414 = 2.0 / 1.414 = 1.414
  score_B = Q_D · K_B / √2 = (1*0 + 1*2) / 1.414 = 2.0 / 1.414 = 1.414
  score_C = Q_D · K_C / √2 = (1*1 + 1*1) / 1.414 = 2.0 / 1.414 = 1.414
  score_D = Q_D · K_D / √2 = (1*1 + 1*1) / 1.414 = 2.0 / 1.414 = 1.414

softmax([1.414, 1.414, 1.414, 1.414]):
  四个值完全相等 → 权重均等 = [0.25, 0.25, 0.25, 0.25]
  (Q_D 对所有历史 token 和自身的关注度相同,符合直觉:x_D 的特征与所有历史 token 距离相等)

输出 = 0.25*V_A + 0.25*V_B + 0.25*V_C + 0.25*V_D
     = 0.25*[2,1] + 0.25*[0,1] + 0.25*[1,2] + 0.25*[1,0]
     = [0.5, 0.25] + [0.0, 0.25] + [0.25, 0.5] + [0.25, 0.0]
     = [1.0, 1.0]

更新 Cache:
  cache_K 追加 K_D → [[2,0], [0,2], [1,1], [1,1]]
  cache_V 追加 V_D → [[2,1], [0,1], [1,2], [1,0]]

对比:若没有 KV Cache,这一步需要重新计算 K_A、K_B、K_C,共 3 次矩阵乘法;
     有了 KV Cache,只计算了 K_D 这 1 次,节省了 75% 的 K/V 计算量。
     序列越长,节省比例越高(生成第 T 步时节省 (T-1)/T)。

六、KV Cache 的工程实现要点

预分配显存(静态 Cache)

推理框架通常在开始前就按最大序列长度预分配好 Cache 显存,避免动态分配的开销:

import torch

max_seq_len = 4096
num_layers = 32
num_kv_heads = 8
head_dim = 128
batch_size = 1   # 推理时通常 batch=1(单请求),服务化场景可设更大值

# 预分配 KV Cache,shape: [层数, 2, batch, kv头数, 最大序列长度, 头维度]
# 第2维的 2 代表 K 和 V 各一份(index 0 = K,index 1 = V)
kv_cache = torch.zeros(num_layers, 2, batch_size, num_kv_heads, max_seq_len, head_dim)

print(f"KV Cache 预分配显存:{kv_cache.numel() * 2 / 1024**3:.2f} GB (fp16)")
# 32层 × 2 × 1 × 8头 × 4096 × 128 × 2字节 = 0.5 GB

动态 Cache(PagedAttention,vLLM 使用)

静态预分配的问题:不同请求序列长度不同,预分配最大长度会浪费显存。

PagedAttention(vLLM 提出)借鉴操作系统虚拟内存的分页思想:

  • 把 KV Cache 切成固定大小的"页"(block)
  • 按需分配,不同请求的 Cache 页可以不连续
  • 支持多个请求共享相同 prompt 的 Cache(prefix sharing)
传统静态 Cache:
  请求A(实际用200 token):预分配4096 token的显存 → 浪费3896 token
  请求B(实际用3000 token):预分配4096 token的显存 → 浪费1096 token

PagedAttention:
  请求A:按需分配 200 token 的 Cache,用多少分多少
  请求B:按需分配 3000 token 的 Cache
  → 显存利用率大幅提升,同等显存可服务更多并发请求

PyTorch 中使用 KV Cache 的简化示例

import torch
import torch.nn as nn
import math

class AttentionWithKVCache(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

        # 只对当前新 token 计算 Q、K、V
        Q = self.split_heads(self.W_Q(x))   # [B, h, T_new, d_k]
        K_new = self.split_heads(self.W_K(x))
        V_new = self.split_heads(self.W_V(x))

        # 拼接历史 Cache
        if past_key_value is not None:
            K_cached, V_cached = past_key_value
            K = torch.cat([K_cached, K_new], dim=2)  # [B, h, T_all, d_k]
            V = torch.cat([V_cached, V_new], dim=2)
        else:
            K, V = K_new, V_new

        # 更新后的 Cache(供下一步使用)
        updated_cache = (K, V)

        # Scaled Dot-Product Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        attended = torch.matmul(weights, V)

        # 合并多头,输出投影
        B, h, T_new, d_k = attended.shape
        attended = attended.transpose(1, 2).contiguous().view(B, T_new, h * d_k)
        output = self.W_O(attended)

        return output, updated_cache


# 使用示例:模拟自回归生成
model = AttentionWithKVCache(d_model=64, num_heads=4)
model.eval()

with torch.no_grad():
    # Prefill:处理 prompt(5个token)
    prompt = torch.randn(1, 5, 64)
    out, kv_cache = model(prompt, past_key_value=None)
    print(f"Prefill 后 Cache K shape: {kv_cache[0].shape}")  # [1, 4, 5, 16]

    # Decode:逐步生成新 token
    for step in range(3):
        new_token = torch.randn(1, 1, 64)   # 每次只输入 1 个新 token
        out, kv_cache = model(new_token, past_key_value=kv_cache)
        print(f"Step {step+1} 后 Cache K shape: {kv_cache[0].shape}")
        # Step 1: [1, 4, 6, 16]
        # Step 2: [1, 4, 7, 16]
        # Step 3: [1, 4, 8, 16]

七、KV Cache 与注意力变体的关系

KV Cache 的显存压力催生了各种注意力变体(详见 MHA 笔记第八节):

KV Cache 太大
    ↓
MQA(2019):所有头共享 K/V → Cache 缩小 h 倍,但性能下降
    ↓
GQA(2023):分组共享 K/V → Cache 缩小 h/g 倍,性能接近 MHA
    ↓
MLA(2024):缓存低秩压缩向量 c → Cache 缩小 ~93%,性能不降反升

Flash Attention 解决的是另一个问题:不是 Cache 太大,而是 Attention 矩阵的 HBM 读写太慢。两者正交,可以同时使用。


八、常见问题

Q:训练时为什么不用 KV Cache? 训练时使用 Teacher Forcing,所有 token 并行输入,不需要逐步生成,因此不存在"历史 K/V 重复计算"的问题。

Q:KV Cache 会不会导致内存溢出(OOM)? 会。序列越长、batch 越大、模型越大,Cache 越大。解决方案:

  • 使用 GQA/MLA 减少 Cache 大小
  • 使用 PagedAttention(vLLM)提高显存利用率
  • 量化 Cache(如 KV Cache int8 量化)
  • 设置最大序列长度限制

Q:多轮对话时 KV Cache 怎么处理? 每轮对话结束后,可以选择:

  • 保留 Cache:下一轮直接在历史 Cache 上追加,速度快,但显存持续增长
  • 清空 Cache:每轮重新 Prefill,显存可控,但每轮都要重新计算历史

Q:为什么 Prefill 快、Decode 慢?

  • Prefill:大矩阵乘法,GPU 并行度高,计算密集型,GPU 利用率接近 100%
  • Decode:每次只处理 1 个 token,矩阵很小,主要时间花在从显存读 KV Cache,内存带宽密集型,GPU 利用率通常只有 10%~30%

九、一句话总结

KV Cache = 空间换时间:把历史 token 的 K/V 存在显存里,每次生成新 token 只算新的 K/V,历史的直接读取,把推理复杂度从 O(T²) 降为 O(T)。

代价是显存随序列长度线性增长,这是 MQA/GQA/MLA/PagedAttention 等技术存在的根本原因。