为什么LLM中使用BF16而不是FP16?

结论:BF16 的指数位与 FP32 相同(8位),与 FP32 动态范围相当,避免LLM训练中梯度下溢问题,现代加速器(如 TPU、NVIDIA Ampere)原生支持 BF16,具有硬件兼容性,更适合 LLM 的数值稳定性需求。


混合精度训练

早期预训练语言模型(如BERT)主要使用单精度浮点数(FP32)表示模型参数并进行优化计算。为训练大语言模型研究人员提出混合精度训练(Mixed Precision Training),同时使用半精度浮点数(2个字节)和单精度浮点数(4个字节)运算,实现显存开销减半、训练效率翻倍效果。

半精度浮点

1. FP16(半精度浮点)

  • 标准:IEEE 754 标准,16位(1位符号 + 5位指数 + 10位尾数)。
  • 特点
    • 数值范围较小,约$[-65504, 65504]$。
    • 易发生 下溢(Underflow)溢出(Overflow)
    • 需要 损失缩放(Loss Scaling) 防止梯度消失。

2. BF16(脑浮点)

  • 标准:由 Google 提出,16位(1位符号 + 8位指数 + 7位尾数)。
  • 特点
    • 动态范围与 FP32 相当,约$[-3.39 \times 10^{38}, 3.39 \times 10^{38}]$。
    • 精度较低,但对大模型更友好(减少梯度下溢风险)。
    • 无需复杂的损失缩放,硬件兼容性逐步提升(如 Intel CPU、TPU、NVIDIA Ampere 架构)。

3. FP16 vs BF16

特性 FP16 BF16
数值范围 窄(易溢出) 宽(接近 FP32)
精度 高(尾数位多) 低(尾数位少)
硬件支持 NVIDIA Volta/Turing NVIDIA Ampere、TPU
适用场景 图像、小模型 大模型、NLP 任务

原理

  1. 参数以 FP32 存储。
  2. 前向计算:框架自动将 FP32 参数转换为 FP16/BF16(通过 autocast)。
  3. 损失计算:在 FP16/BF16 中计算损失,通过 Loss Scaling 放大损失值,避免反向梯度过小而产生下溢(如 scaler.scale(loss))。
  4. 反向计算:基于 FP16/BF16 梯度,放大后的损失值生成 FP16/BF16 梯度。
  5. 梯度转换:框架自动将 FP16/BF16 梯度 cast 到 FP32。
  6. 梯度处理:除以 Loss Scale 还原梯度,检查溢出,如溢出则跳过更新,否则优化器以FP32对原始参数进行更新。

代码示例

PyTorch 混合精度训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch.cuda.amp import GradScaler, autocast

# 初始化模型、优化器
model = torch.nn.Linear(100, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler() # 用于 FP16 自动损失缩放

# 训练循环
for input, target in dataloader:
optimizer.zero_grad()

# 前向计算(自动转为 FP16/BF16)
with autocast(dtype=torch.float16): # 或 torch.bfloat16
output = model(input)
loss = loss_fn(output, target) # 损失在 FP16/BF16 中计算

# 反向传播(自动缩放梯度)
scaler.scale(loss).backward() # 梯度在 FP16/BF16 中计算,自动 cast 到 FP32
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 梯度裁剪
scaler.step(optimizer) # 自动除以 scale 并更新 FP32 参数
scaler.update() # 动态调整 scale 系数

DeepSpeed 混合精度配置

1
2
3
4
5
6
7
8
9
10
11
# 在 DeepSpeed 配置文件中设置
{
"train_batch_size": 32,
"fp16": {
"enabled": true,
"loss_scale": 0 # 0 表示动态损失缩放
},
"bf16": {
"enabled": false # 切换 BF16 时设为 true
}
}

实践技巧

  1. 数值稳定性
    • 监控梯度值,避免 NaN(可通过 torch.isnan() 或日志检查)。
    • 对 RNN 类模型需谨慎使用 FP16。
  2. 显存优化
    • 混合精度可降低显存占用约 30-50%,支持更大 batch size。
问题 解决方案
梯度下溢/溢出 启用动态损失缩放(如 GradScaler
精度下降 调整损失缩放系数或切换 BF16
硬件不支持 BF16 回退到 FP16 或使用软件模拟(慢)