一句话:优化器决定"用梯度怎么更新参数"。从最朴素的 SGD,到加了动量的 SGD Momentum,再到自适应学习率的 Adam/AdamW,每一步改进都在解决上一代的具体问题。
一、梯度下降的基本框架
所有优化器的核心都是:
$$\theta_{t+1} = \theta_t - \text{update}(g_t)$$
其中 $g_t = \nabla_\theta \mathcal{L}$ 是当前 batch 的梯度,不同优化器只是 $\text{update}$ 的计算方式不同。
三种梯度下降变体(按 batch 大小区分):
批量梯度下降(BGD): 用全部数据算梯度,准确但极慢
随机梯度下降(SGD): 用 1 条数据算梯度,快但噪声大
小批量梯度下降(MSGD):用 mini-batch 算梯度,实践中的标准做法
现代所有"SGD"都指小批量梯度下降。
二、SGD:最朴素的优化器
$$\theta_{t+1} = \theta_t - \eta \cdot g_t$$
- $\eta$:学习率(步长)
- $g_t$:当前 mini-batch 的梯度
问题:
- 学习率敏感:太大震荡,太小收敛慢
- 各维度步长相同:不同参数的梯度尺度可能差异极大,一个学习率无法同时适配所有参数
- 容易卡在鞍点:梯度为零但不是极值点
三、SGD + Momentum:加入惯性
$$v_t = \beta \cdot v_{t-1} + g_t$$ $$\theta_{t+1} = \theta_t - \eta \cdot v_t$$
- $v_t$:速度(历史梯度的指数加权平均)
- $\beta$:动量系数,通常为 0.9
直觉:像一个球滚下山坡,历史的速度会叠加到当前更新上。
解决的问题:
- 方向一致的维度加速收敛(速度累积)
- 方向震荡的维度相互抵消(震荡被平滑)
- 能越过较小的局部极值和鞍点
四、Adam:自适应学习率
Adam(Adaptive Moment Estimation,2014,Kingma & Ba)的核心思想:为每个参数单独维护一个自适应学习率。
完整公式
$$m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t \quad \text{(一阶动量,梯度均值)}$$ $$v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 \quad \text{(二阶动量,梯度方差)}$$
偏差修正(前几步 $m$、$v$ 偏小,需要修正):
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$
参数更新:
$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t$$
默认超参数:$\beta_1 = 0.9$,$\beta_2 = 0.999$,$\epsilon = 10^{-8}$,$\eta = 10^{-3}$
为什么 Adam 好
某个参数的梯度一直很大(如 embedding 层的高频词):
→ v_t 很大 → 分母大 → 步长小 → 不会更新过猛
某个参数的梯度一直很小(如深层的稀疏参数):
→ v_t 很小 → 分母小 → 步长大 → 不会更新太慢
效果:自动为每个参数找到合适的步长,对学习率的选择不那么敏感。
AdamW:修复权重衰减
标准 Adam 的权重衰减实现有 bug:L2 正则化会被自适应学习率缩放,导致正则化效果不稳定。
AdamW 把权重衰减从梯度里解耦出来,直接作用于参数:
$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t - \eta \lambda \theta_t$$
其中 $\lambda$ 是权重衰减系数(通常 0.01~0.1)。LLM 训练几乎都用 AdamW。
五、Adam 的显存开销
这是 Adam 最大的缺点:需要为每个参数额外存储两个状态($m_t$ 和 $v_t$),全部用 FP32。
对于参数量为 $N$ 的模型:
模型权重(BF16):N × 2 bytes
梯度(BF16): N × 2 bytes
m_t(FP32): N × 4 bytes
v_t(FP32): N × 4 bytes
─────────────────────────
合计: N × 12 bytes
以 7B 模型为例:$7 \times 10^9 \times 12 \approx 84\ \text{GB}$,光优化器状态就占 56 GB。
Adam 8bit(bitsandbytes) 的解决方案:把 $m_t$、$v_t$ 量化到 8bit,显存从 $N \times 12$ 降到 $N \times 6$ bytes,节省约 50%。
六、MNIST 完整示例:对比 SGD、Momentum、Adam
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
class MnistNet(nn.Module):
"""
简单的 3 层全连接网络,用于 MNIST 分类。
输入:28×28 = 784 维,输出:10 类
"""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(x.size(0), -1) # [batch, 1, 28, 28] → [batch, 784]
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x) # 返回 logits,不做 softmax(CrossEntropyLoss 内部处理)
def train_one_epoch(
model: nn.Module,
optimizer: torch.optim.Optimizer,
loader: DataLoader,
device: torch.device,
) -> tuple[float, float]:
"""训练一个 epoch,返回 (平均 loss, 准确率)。"""
model.train()
total_loss = 0.0
total_correct = 0
total_samples = 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
logits = model(images)
loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * images.size(0)
total_correct += (logits.argmax(dim=1) == labels).sum().item()
total_samples += images.size(0)
return total_loss / total_samples, total_correct / total_samples
def evaluate(
model: nn.Module,
loader: DataLoader,
device: torch.device,
) -> float:
"""在测试集上评估准确率。"""
model.eval()
total_correct = 0
total_samples = 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
logits = model(images)
total_correct += (logits.argmax(dim=1) == labels).sum().item()
total_samples += images.size(0)
return total_correct / total_samples
def count_optimizer_memory_bytes(optimizer: torch.optim.Optimizer) -> int:
"""估算优化器状态占用的字节数。"""
total_bytes = 0
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer.state[param]
for tensor in state.values():
if isinstance(tensor, torch.Tensor):
total_bytes += tensor.numel() * tensor.element_size()
return total_bytes
def run_experiment(
optimizer_name: str,
optimizer: torch.optim.Optimizer,
num_epochs: int,
train_loader: DataLoader,
test_loader: DataLoader,
device: torch.device,
model: nn.Module,
) -> None:
print("=" * 50)
print("Optimizer: %s" % optimizer_name)
print("=" * 50)
start_time = time.time()
for epoch in range(1, num_epochs + 1):
train_loss, train_acc = train_one_epoch(model, optimizer, train_loader, device)
test_acc = evaluate(model, test_loader, device)
print(" Epoch %2d | loss=%.4f | train_acc=%.4f | test_acc=%.4f" % (
epoch, train_loss, train_acc, test_acc))
elapsed = time.time() - start_time
# 统计优化器显存(只有 Adam 在第一步后才会分配状态)
optimizer_mem_bytes = count_optimizer_memory_bytes(optimizer)
model_param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(" Time: %.1fs" % elapsed)
print(" Model params: %.1f KB" % (model_param_bytes / 1024))
print(" Optimizer state: %.1f KB" % (optimizer_mem_bytes / 1024))
if model_param_bytes > 0:
print(" Opt/Param ratio: %.1fx" % (optimizer_mem_bytes / model_param_bytes))
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 数据集下载路径(兼容 Windows 和 Linux/Mac)
import os
mnist_root = os.path.join(os.path.expanduser("~"), "tmp", "mnist")
os.makedirs(mnist_root, exist_ok=True)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)), # MNIST 的均值和标准差
])
train_dataset = datasets.MNIST(root=mnist_root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=mnist_root, train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=0)
num_epochs = 5
# ── 实验 1:SGD(无动量)──────────────────────────────────────────────────
model_sgd = MnistNet().to(device)
optimizer_sgd = torch.optim.SGD(model_sgd.parameters(), lr=0.01)
run_experiment("SGD (lr=0.01)", optimizer_sgd, num_epochs,
train_loader, test_loader, device, model_sgd)
# ── 实验 2:SGD + Momentum ────────────────────────────────────────────────
model_momentum = MnistNet().to(device)
optimizer_momentum = torch.optim.SGD(model_momentum.parameters(), lr=0.01, momentum=0.9)
run_experiment("SGD + Momentum (lr=0.01, β=0.9)", optimizer_momentum, num_epochs,
train_loader, test_loader, device, model_momentum)
# ── 实验 3:Adam ──────────────────────────────────────────────────────────
model_adam = MnistNet().to(device)
optimizer_adam = torch.optim.Adam(model_adam.parameters(), lr=0.001)
run_experiment("Adam (lr=0.001)", optimizer_adam, num_epochs,
train_loader, test_loader, device, model_adam)
# ── 实验 4:AdamW ─────────────────────────────────────────────────────────
model_adamw = MnistNet().to(device)
optimizer_adamw = torch.optim.AdamW(model_adamw.parameters(), lr=0.001, weight_decay=0.01)
run_experiment("AdamW (lr=0.001, wd=0.01)", optimizer_adamw, num_epochs,
train_loader, test_loader, device, model_adamw)
# ── 显存汇总 ──────────────────────────────────────────────────────────────
print(f"\n{'='*50}")
print("显存占用汇总(MNIST 小模型,参数量约 200K)")
print(f"{'='*50}")
model_param_bytes = sum(p.numel() * p.element_size() for p in model_adam.parameters())
print(f"模型参数(FP32):{model_param_bytes / 1024:.0f} KB")
for name, opt in [("SGD", optimizer_sgd), ("SGD+Momentum", optimizer_momentum),
("Adam", optimizer_adam), ("AdamW", optimizer_adamw)]:
opt_bytes = count_optimizer_memory_bytes(opt)
print(f"{name:20s}:优化器状态 {opt_bytes/1024:.0f} KB(模型参数的 {opt_bytes/model_param_bytes:.1f}x)")
七、各优化器显存对比(理论)
对参数量为 $N$ 的模型,以 FP32 存储为基准(4 bytes/参数):
| 优化器 | 额外存储的状态 | 额外显存 | 总显存(含参数) |
|---|---|---|---|
| SGD | 无 | 0 | $N \times 4$ bytes |
| SGD + Momentum | 速度 $v$(FP32) | $N \times 4$ | $N \times 8$ bytes |
| Adam / AdamW | $m$(FP32)+ $v$(FP32) | $N \times 8$ | $N \times 12$ bytes |
| Adam 8bit | $m$(INT8)+ $v$(INT8) | $N \times 2$ | $N \times 6$ bytes |
以 7B 模型(BF16 权重)为例:
模型权重(BF16): 7B × 2 = 14 GB
梯度(BF16): 7B × 2 = 14 GB
Adam m_t(FP32): 7B × 4 = 28 GB
Adam v_t(FP32): 7B × 4 = 28 GB
─────────────────────────────────────
全量微调合计: 84 GB
Adam 8bit 的优化器状态:7B × 1 + 7B × 1 = 14 GB(节省 42 GB)
LoRA 只对 ~1% 参数更新:Adam 状态 ≈ 0.56 GB(几乎可忽略)
八、核心要点速查
| 问题 | 答案 |
|---|---|
| SGD 的缺点? | 各参数步长相同,对学习率敏感,容易震荡或卡鞍点 |
| Momentum 解决什么? | 用历史梯度的指数平均平滑更新,加速一致方向、抑制震荡方向 |
| Adam 的两个动量是什么? | $m_t$:梯度均值(一阶);$v_t$:梯度方差(二阶) |
| 为什么要偏差修正? | 初始 $m_0=v_0=0$,前几步估计偏小,除以 $(1-\beta^t)$ 修正 |
| AdamW 和 Adam 的区别? | AdamW 把权重衰减从梯度解耦,正则化效果更稳定 |
| Adam 为什么显存贵? | 需要额外存 $m_t$ 和 $v_t$,均为 FP32,是参数量的 2 倍显存 |
| Adam 8bit 怎么省显存? | 把 $m_t$、$v_t$ 量化到 8bit,优化器状态显存减少 75% |
| LLM 微调用哪个? | adamw_8bit(bitsandbytes 提供),显存省、效果好 |