LLM —— Parameter Calculation
模型参数量计算与效率分析
本节主要参考中国人民大学的《大语言模型》一书的第6.4章节。介绍 Transformer 架构大语言模型的参数量计算、训练运算量、训练时间及显存开销的估算方法。
参数量计算
以 LLaMA 模型为例,假设词表大小 V,解码器层数 L,隐层维度 H,前馈网络中间维度 H’:
- 输入嵌入层:参数矩阵 $E \in \mathbb{R}^{V \times H} → VH$ 参数。
- 多头注意力层:
- 查询 $W^Q$、键 $W^K$、值 $W^V$ 矩阵:各 $H^2$ 参数 → $3H^2$。
- 输出投影矩阵 $W^O \in \mathbb{R}^{H \times H}$ → $H^2$ 参数。
- 前馈网络层:
- 上投影 $W^U, W^G \in \mathbb{R}^{H \times H’}$ → $2HH’$。
- 下投影 $W^D \in \mathbb{R}^{H’ \times H}$ → $HH’$。
- 归一化层:
- 每层 2 个 RMSNorm → $2H \times L$。
- 末层额外归一化 → $H$。
- 输出层:线性变换 $W^L \in \mathbb{R}^{H \times V}$ → $VH$ 参数。
总参数量公式:
$$
\text{参数量} = 2VH + H + L \cdot (4H^2 + 3HH’ + 2H)
$$
示例(LLaMA-7B):
$V=32000, L=32, H=4096, H’=11008$ → 计算得 $6,738,415,616$(与实际一致)。
训练运算量估计
训练运算量以浮点操作数(FLOP)衡量,浮点运算包括浮点数的加减乘除、指数、对数、三角函数等运算。Transformer 架构训练运算量主要在多头注意力和线性变换计算。归一化、输出映射和旋转位置编码计算所需的运算量较少。设定以下参数:模型总参数量为𝑃,批处理大小为𝐵,输入序列长度为𝑇,训练词元总数为𝐶 = 𝐵𝑇;多头注意力机制包含𝑁 个头,每个头的维度为𝐷,𝐻 = 𝑁𝐷。
矩阵乘法 FLOP:$A \in \mathbb{R}^{n \times m} \times B \in \mathbb{R}^{m \times p}$ 需 $2mnp$ FLOP。
多头注意力运算量
- 前向传播:
- $Q,K,V \in \mathbb{R}^{B \times T \times H} $, 拆分和转置后 $Q’,K’,V’ \in \mathbb{R}^{B \times N \times T \times D} $,$Q’K’^\top$ 计算:$2BT^2ND$
- Softmax(含标准化$\sqrt D$放缩、指数、加和、归一化):$4BT^2N$
- 结果与 $V’$ 乘:$2BT^2ND$
- 反向传播约为前向的 2 倍 → 总 FLOP:
$$
12 \cdot (BT^2ND + BT^2N) \cdot L = 12CTL \cdot (H + N)
$$
线性变换运算量
- 前向传播 FLOP = $2BTHH’$。
- 反向传播需 2 倍前向 FLOP → 总 FLOP:
$$
\text{运算量} = 6C \cdot (\text{线性变换参数量})
$$ - 激活重计算时(额外进行一次前向传播):
$$
\text{运算量} = 8C \cdot (\text{线性变换参数量})
$$
简化估算(线性变换参数量占比总参数量 >95%):
$$
\text{总运算量} \approx 6CP \quad \text{或} \quad 8CP
$$
示例:LLaMA-7B($P=6.74 \times 10^9, C=10^9$)→ $6 \times 6.74 \times 10^9 \times 10^9 = 4.04 \times 10^{19}$ FLOP。
训练时间估计
训练时间的计算主要包括浮点数运算、数据读写以及多进程同步等。其中浮点数运算的耗时是训练过程中最主要的部分。实际每秒浮点运算数 FLOPS 通常为理论值的 30%-70%。
$$
\text{训练时间} = \frac{\text{运算量}}{\text{GPU 数量} \times \text{实际FLOPS}}
$$
示例(LLaMA-65B):
- 参数量 $P=6.5 \times 10^{10}$,词元数 $C=1.4 \times 10^{12}$(激活重计算 → $8CP = 7.28 \times 10^{23}$ FLOP)。
- 使用 2048 张 A100(假设单卡实际算力 $2 \times 10^{14}$次BF16的FLOPS),论文公布训练时间为 21 天。
$$
\text{时间} = \frac{7.28 \times 10^{23}}{2048 \times 2 \times 10^{14}} \approx 1.78 \times 10^6 \text{ 秒} \approx 20.6 \text{ 天}。
$$
训练显存估计
显存占用分三部分:模型参数与优化器、激活值、其他开销。
1. 模型参数与优化器显存
现有大模型训练中通常会采用混合精度训练,模型参数和梯度通常以16位浮点数存储,而Adam或AdamW优化器需要额外存储32位浮点数的模型参数、动量参数以及动量二阶矩参数。
优化方案 | 单 GPU 显存占用(字节) | 优化原理 |
---|---|---|
无 ZeRO | $16P$ | 一个16位浮点数需要2字节,一个32位浮点数需要4字节,因此模型参数和梯度各需要2𝑃字节的显存,Adam优化器的模型参数、动量参数以及动量二阶矩参数则各需要4𝑃字节的显存。累和每张GPU 上需要$(2+2+4+4+4)·𝑃 = 16𝑃$字节的显存。 |
ZeRO-1(优化器分区) | $4P + 12P/N_D$ | 优化器参数平摊到每张GPU 上,模型参数和梯度在每张显卡各自保留。每张GPU 上会需要$(2 + 2) · 𝑃 + (4 + 4 + 4) · 𝑃/𝑁_𝐷$字节显存。 |
ZeRO-2(梯度分区) | $2P + 14P/N_D$ | 将模型梯度也平摊到每张GPU 上。每张GPU 上会需要使用$2𝑃 + (2 + 4 + 4 + 4) · 𝑃/𝑁_𝐷$字节的显存。 |
ZeRO-3(参数分区) | $16P/N_D$ | 将模型参数也平摊到每张GPU 上。每张GPU需要使用$16𝑃/𝑁_𝐷$ 字节显存。 |
($N_D$:数据并行 GPU 数;张量/流水线并行时需额外除以 $N_T \times N_P$)。
2. 激活值显存
存储前向传播的中间结果用于反向传播梯度计算
1. 多头自注意力层
输入矩阵$X$ + $QKV$ + $QK^T$ + concat
$$ 3×2BTH + 3×2BTH + 2BT²N + 2BTH = 16BTH + 2BT²N $$
2. 前馈网络层(SwiGLU 结构)
网络输入$X$ + 门控线性变换$W^GX$和$W^UX$ + 激活输出$\sigma$
$$
2BTH + 2\times 2BTH’ + 2BTH’ = 2BTH + 6BTH’
$$
3. 归一化层
注意力层前 RMSNorm + 前馈网络层前 RMSNorm:
$$
2BTH + 2BTH = 4BTH
$$
4. 输出层
最终层归一化 + 末层解码器输出 $Y_L$ + Softmax:
$$
2BTH + 2BTH + 4BTV = 4BTH + 4BTV
$$
总激活值显存公式
$$
\text{激活值} = (16BTH + 6BTH’ + 2BT^2N) \times L + 4BTH + 4BTV
$$
优化技术对激活值的影响
技术 | 激活值显存公式 | 优化原理 |
---|---|---|
FlashAttention | 移除公式中的 $2BT^2N$ 项 | 避免存储 $QK^\top$ 矩阵 |
流水线并行 ($N_P$) | $(16BTH + 6BTH' + 2BT^2N) \times \frac{L}{N_P} + 4BTH + 4BTV $ | 层拆分到不同 GPU |
张量并行 ($N_T$) | $\left[\left(8 + \frac{8}{N_T}\right)BTH + \frac{6BTH'}{N_T} + \frac{2BT^2N}{N_T}\right]L + 4BTH + 4BTV $ | 矩阵运算拆分到 GPU |
激活重计算 | $(4 + 2L)BTH + 4BTV $ | 前向传播时仅保存每层的输入和最后层softmax的输入,反向传播时按需重计算激活值 |
3. 其他显存占用
固定开销:
- 框架内核:PyTorch 加载需 0.8–1 GB
- ZeRO实现:DeepSpeed 库占用 1–4 GB(随优化等级增加)
- 显存碎片:分配不连续导致,约 0.5–1 GB
动态开销:
- Softmax中间结果:计算时需要32位精度 → 占用 $8BTV$ 字节
综合计算示例:LLaMA-7B 训练显存
训练配置
- GPU:2 × A800 (80GB)
- 并行策略:ZeRO-3 + FlashAttention + 激活重计算
- 超参数:$L=32, H=4096, V=32000, B=8, T=2048$
显存分解计算
参数与优化器(ZeRO-3):
$$
\frac{16 \times 6.74 \times 10^9}{2} = 5.39 \times 10^{10} \text{ 字节} \approx 50.20 \text{ GB}
$$激活值(FlashAttention):
$$
(4 + 2 \times 32) \times 8 \times 2048 \times 4096 + 4 \times 8 \times 2048 \times 32000 = 6.66 \times 10^9 \text{ 字节} \approx 6.20 \text{ GB}
$$Softmax中间结果:
$$
8 \times 8 \times 2048 \times 32000 = 4.19 \times 10^9 \text{ 字节} \approx 3.91 \text{ GB}
$$其他开销:
- 框架 1 GB + ZeRO 2 GB + 碎片 1 GB + 预留 2 GB ≈ 6 GB
总占用/GPU:$50.20 + 6.20 + 3.91 + 6 = 66.31 \text{ GB}$
训练资源配置建议
模型规模 | 参数量 | 最低显存需求 | 推荐GPU配置 | 批次大小 $B$ |
---|---|---|---|---|
13B | 13×10⁹ | 208 GB | 3 × 80GB GPU | ≤2(效率低) |
4 × 80GB GPU | 12 | |||
30B | 30×10⁹ | 480 GB | 8 × 80GB GPU | 按需调整 |
65B | 65×10⁹ | 1040 GB | 16 × 80GB GPU | 按需调整 |
模型参数与优化器占用主导显存需求,需至少 16倍参数量 的显存资源。并行技术和ZeRO优化可显著降低单卡负载。