一句话:KV Cache 是推理时把已经算过的 Key 和 Value 缓存起来,避免每生成一个新 token 都重复计算历史 token 的 K/V,是大模型推理加速的核心技术。
一、为什么需要 KV Cache?
大模型生成 token 的方式:自回归(Auto-regressive)
大模型生成文本是一个 token 一个 token 地生成的,每次生成下一个 token 时,都要把所有历史 token 作为输入重新过一遍 Transformer。
输入: "今天天气"
生成: "今天天气" → "很"
生成: "今天天气很" → "好"
生成: "今天天气很好" → "!"
每一步生成,Attention 都要计算:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
其中 Q、K、V 都来自当前所有 token(包括历史的)。
没有 KV Cache 时的重复计算
假设已经生成了 t 个 token,现在要生成第 t+1 个:
第1步生成"很":
对 token ["今","天","天","气"] 全部计算 K、V
→ K = [K_今, K_天, K_天, K_气]
→ V = [V_今, V_天, V_天, V_气]
第2步生成"好":
对 token ["今","天","天","气","很"] 全部计算 K、V
→ K = [K_今, K_天, K_天, K_气, K_很] ← 前4个和上一步完全一样!
→ V = [V_今, V_天, V_天, V_气, V_很] ← 前4个和上一步完全一样!
第3步生成"!":
对 token ["今","天","天","气","很","好"] 全部计算 K、V
→ 前5个 K/V 和上一步完全一样!
历史 token 的 K/V 每次都重新算,纯属浪费。
生成长度为 T 的序列,总计算量是 O(T²),随序列变长急剧增加。
二、KV Cache 的原理
核心思想
历史 token 的 K 和 V 不会随新 token 的生成而改变(在 Decoder-only 模型中,历史 token 不会 attend 到未来 token,所以它们的 K/V 是固定的)。
因此,把每一层每个历史 token 的 K/V 存下来,下次直接读取,不再重新计算。
有 KV Cache 后的计算流程
第1步生成"很":
计算所有 token 的 K/V:[K_今, K_天, K_天, K_气]
→ 存入 Cache:cache_K = [K_今, K_天, K_天, K_气]
cache_V = [V_今, V_天, V_天, V_气]
第2步生成"好":
只计算新 token "很" 的 K/V:K_很, V_很
→ 从 Cache 读出历史:[K_今, K_天, K_天, K_气]
→ 拼接:K = [K_今, K_天, K_天, K_气, K_很] ← 直接用,不重算
→ 更新 Cache:cache_K = [K_今, K_天, K_天, K_气, K_很]
第3步生成"!":
只计算新 token "好" 的 K/V:K_好, V_好
→ 从 Cache 读出历史,拼接,继续...
每步只需计算 1 个新 token 的 K/V,历史的全部从缓存读取。
计算量从 O(T²) 降为 O(T),速度提升显著。
三、KV Cache 的显存占用计算
KV Cache 的显存占用是推理时显存的主要来源之一,计算公式:
$$ \text{KV Cache 大小} = 2 \times L \times h \times d_k \times T \times \text{bytes_per_element} $$
| 参数 | 含义 |
|---|---|
| 2 | K 和 V 各一份 |
| L | Transformer 层数 |
| h | 注意力头数 |
| d_k | 每个头的维度 |
| T | 序列长度(已生成的 token 数) |
| bytes_per_element | 数据精度(fp16=2字节,fp32=4字节,int8=1字节) |
具体例子:LLaMA-3-8B 的 KV Cache
LLaMA-3-8B 的参数:
- 层数 L = 32
- 头数 h = 32(Q头),KV头数 = 8(GQA,g=8)
- 每头维度 d_k = 128
- 精度 fp16(2字节)
序列长度 T = 4096 时的 KV Cache:
KV Cache = 2 × 32层 × 8个KV头 × 128维 × 4096 token × 2字节
= 2 × 32 × 8 × 128 × 4096 × 2
= 536,870,912 字节
≈ 512 MB
序列长度 T = 32768(32K 上下文)时:
KV Cache = 512 MB × (32768 / 4096) = 512 MB × 8 = 4 GB
这就是为什么长上下文推理显存消耗巨大,也是 MQA/GQA/MLA 要压缩 KV Cache 的根本原因。
四、KV Cache 在哪里用?
只在 Decoder(生成阶段)使用
| 场景 | 是否用 KV Cache | 原因 |
|---|---|---|
| GPT 类模型推理(生成) | ✅ 用 | 自回归生成,历史 K/V 固定 |
| BERT 类模型推理(理解) | ❌ 不用 | 一次性处理全部输入,不逐步生成 |
| Encoder-Decoder(如 T5) | Decoder 部分用 | Encoder 输出固定,Decoder 自回归 |
| 训练阶段 | ❌ 不用 | 训练时并行处理所有 token,不需要逐步生成 |
Prefill 阶段 vs Decode 阶段
大模型推理分两个阶段:
Prefill(预填充)阶段:
输入:完整的 prompt(如 "请帮我写一首诗:\n")
操作:一次性并行计算所有 prompt token 的 K/V,存入 Cache
特点:计算密集,GPU 利用率高
Decode(解码)阶段:
输入:每次只有 1 个新生成的 token
操作:计算新 token 的 K/V,从 Cache 读取历史 K/V,拼接后做 Attention
特点:内存带宽密集(主要时间花在读 Cache),GPU 利用率低
五、具体计算过程举例
设定
- 模型:单层 Transformer,2个注意力头,每头维度 d_k = 2
- 已有 prompt:[“A”, “B”],现在要生成第3个 token
Prefill 阶段(处理 prompt “A B”)
固定投影矩阵(整个例子全程使用这三个矩阵):
W_K = [[1, 0], W_V = [[1, 1], W_Q = [[1, 0],
[0, 1], [0, 1], [1, 1],
[1, 0], [1, 0], [0, 1],
[0, 1]] [0, 0]] [1, 0]]
(均为 [4, 2] 的矩阵,将 4 维 embedding 投影到 2 维)
token embedding(固定数值):
x_A = [1, 0, 1, 0]
x_B = [0, 1, 0, 1]
x_C = [1, 1, 0, 0] ← Decode 第1步的新 token
x_D = [0, 0, 1, 1] ← Decode 第2步的新 token
计算 K/V(K = x @ W_K,V = x @ W_V):
K_A = x_A @ W_K = [1*1+0*0+1*1+0*0, 1*0+0*1+1*0+0*1] = [2, 0]
K_B = x_B @ W_K = [0*1+1*0+0*1+1*0, 0*0+1*1+0*0+1*1] = [0, 2]
V_A = x_A @ W_V = [1*1+0*0+1*1+0*0, 1*1+0*1+1*0+0*0] = [2, 1]
V_B = x_B @ W_V = [0*1+1*0+0*1+1*0, 0*1+1*1+0*0+1*0] = [0, 1]
存入 KV Cache:
cache_K = [[2, 0], ← K_A
[0, 2]] ← K_B
cache_V = [[2, 1], ← V_A
[0, 1]] ← V_B
Decode 阶段第1步(生成 token C)
x_C = [1, 1, 0, 0](已在 Prefill 阶段给出)
只计算 C 的 K/V(使用与 Prefill 相同的 W_K、W_V):
K_C = x_C @ W_K = [1*1+1*0+0*1+0*0, 1*0+1*1+0*0+0*1] = [1, 1]
V_C = x_C @ W_V = [1*1+1*0+0*1+0*0, 1*1+1*1+0*0+0*0] = [1, 2]
从 Cache 读取历史,拼接:
K_all = [[2, 0], ← K_A(从 Cache 读,不重新计算)
[0, 2], ← K_B(从 Cache 读,不重新计算)
[1, 1]] ← K_C(刚计算)
V_all = [[2, 1], ← V_A(从 Cache 读,不重新计算)
[0, 1], ← V_B(从 Cache 读,不重新计算)
[1, 2]] ← V_C(刚计算)
Q_C = x_C @ W_Q = [1*1+1*1+0*0+0*1, 1*0+1*1+0*1+0*0] = [2, 1]
Attention 分数(Q_C 对所有 K):
score_A = Q_C · K_A / √2 = (2*2 + 1*0) / 1.414 = 4.0 / 1.414 = 2.828
score_B = Q_C · K_B / √2 = (2*0 + 1*2) / 1.414 = 2.0 / 1.414 = 1.414
score_C = Q_C · K_C / √2 = (2*1 + 1*1) / 1.414 = 3.0 / 1.414 = 2.121
softmax([2.828, 1.414, 2.121]):
e^2.828=16.930, e^1.414=4.113, e^2.121=8.337, 总和=29.380
权重 = [16.930/29.380, 4.113/29.380, 8.337/29.380]
= [0.576, 0.140, 0.284]
输出 = 0.576*V_A + 0.140*V_B + 0.284*V_C
= 0.576*[2,1] + 0.140*[0,1] + 0.284*[1,2]
= [1.152, 0.576] + [0.000, 0.140] + [0.284, 0.568]
= [1.436, 1.284]
更新 Cache:
cache_K 追加 K_C → [[2,0], [0,2], [1,1]]
cache_V 追加 V_C → [[2,1], [0,1], [1,2]]
Decode 阶段第2步(生成 token D)
x_D = [0, 0, 1, 1](已在 Prefill 阶段给出)
只计算 D 的 K/V(使用与 Prefill 相同的 W_K、W_V):
K_D = x_D @ W_K = [0*1+0*0+1*1+1*0, 0*0+0*1+1*0+1*1] = [1, 1]
V_D = x_D @ W_V = [0*1+0*0+1*1+1*0, 0*1+0*1+1*0+1*0] = [1, 0]
从 Cache 读取历史(此时 Cache 已包含 A、B、C):
cache_K = [[2, 0], ← K_A(不重新计算)
[0, 2], ← K_B(不重新计算)
[1, 1]] ← K_C(不重新计算)
拼接 K_D:
K_all = [[2, 0], ← K_A(Cache 读取)
[0, 2], ← K_B(Cache 读取)
[1, 1], ← K_C(Cache 读取)
[1, 1]] ← K_D(刚计算)
V_all = [[2, 1], ← V_A(Cache 读取)
[0, 1], ← V_B(Cache 读取)
[1, 2], ← V_C(Cache 读取)
[1, 0]] ← V_D(刚计算)
Q_D = x_D @ W_Q = [0*1+0*1+1*0+1*1, 0*0+0*1+1*1+1*0] = [1, 1]
Attention 分数(Q_D 对所有 K):
score_A = Q_D · K_A / √2 = (1*2 + 1*0) / 1.414 = 2.0 / 1.414 = 1.414
score_B = Q_D · K_B / √2 = (1*0 + 1*2) / 1.414 = 2.0 / 1.414 = 1.414
score_C = Q_D · K_C / √2 = (1*1 + 1*1) / 1.414 = 2.0 / 1.414 = 1.414
score_D = Q_D · K_D / √2 = (1*1 + 1*1) / 1.414 = 2.0 / 1.414 = 1.414
softmax([1.414, 1.414, 1.414, 1.414]):
四个值完全相等 → 权重均等 = [0.25, 0.25, 0.25, 0.25]
(Q_D 对所有历史 token 和自身的关注度相同,符合直觉:x_D 的特征与所有历史 token 距离相等)
输出 = 0.25*V_A + 0.25*V_B + 0.25*V_C + 0.25*V_D
= 0.25*[2,1] + 0.25*[0,1] + 0.25*[1,2] + 0.25*[1,0]
= [0.5, 0.25] + [0.0, 0.25] + [0.25, 0.5] + [0.25, 0.0]
= [1.0, 1.0]
更新 Cache:
cache_K 追加 K_D → [[2,0], [0,2], [1,1], [1,1]]
cache_V 追加 V_D → [[2,1], [0,1], [1,2], [1,0]]
对比:若没有 KV Cache,这一步需要重新计算 K_A、K_B、K_C,共 3 次矩阵乘法;
有了 KV Cache,只计算了 K_D 这 1 次,节省了 75% 的 K/V 计算量。
序列越长,节省比例越高(生成第 T 步时节省 (T-1)/T)。
六、KV Cache 的工程实现要点
预分配显存(静态 Cache)
推理框架通常在开始前就按最大序列长度预分配好 Cache 显存,避免动态分配的开销:
import torch
max_seq_len = 4096
num_layers = 32
num_kv_heads = 8
head_dim = 128
batch_size = 1 # 推理时通常 batch=1(单请求),服务化场景可设更大值
# 预分配 KV Cache,shape: [层数, 2, batch, kv头数, 最大序列长度, 头维度]
# 第2维的 2 代表 K 和 V 各一份(index 0 = K,index 1 = V)
kv_cache = torch.zeros(num_layers, 2, batch_size, num_kv_heads, max_seq_len, head_dim)
print(f"KV Cache 预分配显存:{kv_cache.numel() * 2 / 1024**3:.2f} GB (fp16)")
# 32层 × 2 × 1 × 8头 × 4096 × 128 × 2字节 = 0.5 GB
动态 Cache(PagedAttention,vLLM 使用)
静态预分配的问题:不同请求序列长度不同,预分配最大长度会浪费显存。
PagedAttention(vLLM 提出)借鉴操作系统虚拟内存的分页思想:
- 把 KV Cache 切成固定大小的"页"(block)
- 按需分配,不同请求的 Cache 页可以不连续
- 支持多个请求共享相同 prompt 的 Cache(prefix sharing)
传统静态 Cache:
请求A(实际用200 token):预分配4096 token的显存 → 浪费3896 token
请求B(实际用3000 token):预分配4096 token的显存 → 浪费1096 token
PagedAttention:
请求A:按需分配 200 token 的 Cache,用多少分多少
请求B:按需分配 3000 token 的 Cache
→ 显存利用率大幅提升,同等显存可服务更多并发请求
PyTorch 中使用 KV Cache 的简化示例
import torch
import torch.nn as nn
import math
class AttentionWithKVCache(nn.Module):
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
B, T, _ = x.shape
return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
def forward(
self,
x: torch.Tensor,
past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
# 只对当前新 token 计算 Q、K、V
Q = self.split_heads(self.W_Q(x)) # [B, h, T_new, d_k]
K_new = self.split_heads(self.W_K(x))
V_new = self.split_heads(self.W_V(x))
# 拼接历史 Cache
if past_key_value is not None:
K_cached, V_cached = past_key_value
K = torch.cat([K_cached, K_new], dim=2) # [B, h, T_all, d_k]
V = torch.cat([V_cached, V_new], dim=2)
else:
K, V = K_new, V_new
# 更新后的 Cache(供下一步使用)
updated_cache = (K, V)
# Scaled Dot-Product Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
weights = torch.softmax(scores, dim=-1)
attended = torch.matmul(weights, V)
# 合并多头,输出投影
B, h, T_new, d_k = attended.shape
attended = attended.transpose(1, 2).contiguous().view(B, T_new, h * d_k)
output = self.W_O(attended)
return output, updated_cache
# 使用示例:模拟自回归生成
model = AttentionWithKVCache(d_model=64, num_heads=4)
model.eval()
with torch.no_grad():
# Prefill:处理 prompt(5个token)
prompt = torch.randn(1, 5, 64)
out, kv_cache = model(prompt, past_key_value=None)
print(f"Prefill 后 Cache K shape: {kv_cache[0].shape}") # [1, 4, 5, 16]
# Decode:逐步生成新 token
for step in range(3):
new_token = torch.randn(1, 1, 64) # 每次只输入 1 个新 token
out, kv_cache = model(new_token, past_key_value=kv_cache)
print(f"Step {step+1} 后 Cache K shape: {kv_cache[0].shape}")
# Step 1: [1, 4, 6, 16]
# Step 2: [1, 4, 7, 16]
# Step 3: [1, 4, 8, 16]
七、KV Cache 与注意力变体的关系
KV Cache 的显存压力催生了各种注意力变体(详见 MHA 笔记第八节):
KV Cache 太大
↓
MQA(2019):所有头共享 K/V → Cache 缩小 h 倍,但性能下降
↓
GQA(2023):分组共享 K/V → Cache 缩小 h/g 倍,性能接近 MHA
↓
MLA(2024):缓存低秩压缩向量 c → Cache 缩小 ~93%,性能不降反升
Flash Attention 解决的是另一个问题:不是 Cache 太大,而是 Attention 矩阵的 HBM 读写太慢。两者正交,可以同时使用。
八、常见问题
Q:训练时为什么不用 KV Cache? 训练时使用 Teacher Forcing,所有 token 并行输入,不需要逐步生成,因此不存在"历史 K/V 重复计算"的问题。
Q:KV Cache 会不会导致内存溢出(OOM)? 会。序列越长、batch 越大、模型越大,Cache 越大。解决方案:
- 使用 GQA/MLA 减少 Cache 大小
- 使用 PagedAttention(vLLM)提高显存利用率
- 量化 Cache(如 KV Cache int8 量化)
- 设置最大序列长度限制
Q:多轮对话时 KV Cache 怎么处理? 每轮对话结束后,可以选择:
- 保留 Cache:下一轮直接在历史 Cache 上追加,速度快,但显存持续增长
- 清空 Cache:每轮重新 Prefill,显存可控,但每轮都要重新计算历史
Q:为什么 Prefill 快、Decode 慢?
- Prefill:大矩阵乘法,GPU 并行度高,计算密集型,GPU 利用率接近 100%
- Decode:每次只处理 1 个 token,矩阵很小,主要时间花在从显存读 KV Cache,内存带宽密集型,GPU 利用率通常只有 10%~30%
九、一句话总结
KV Cache = 空间换时间:把历史 token 的 K/V 存在显存里,每次生成新 token 只算新的 K/V,历史的直接读取,把推理复杂度从 O(T²) 降为 O(T)。
代价是显存随序列长度线性增长,这是 MQA/GQA/MLA/PagedAttention 等技术存在的根本原因。