模型参数量计算与效率分析

本节主要参考中国人民大学的《大语言模型》一书的第6.4章节。介绍 Transformer 架构大语言模型的参数量计算、训练运算量、训练时间及显存开销的估算方法。


参数量计算

以 LLaMA 模型为例,假设词表大小 V,解码器层数 L,隐层维度 H,前馈网络中间维度 H’:

  1. 输入嵌入层:参数矩阵 $E \in \mathbb{R}^{V \times H} → VH$ 参数。
  2. 多头注意力层
    • 查询 $W^Q$、键 $W^K$、值 $W^V$ 矩阵:各 $H^2$ 参数 → $3H^2$。
    • 输出投影矩阵 $W^O \in \mathbb{R}^{H \times H}$ → $H^2$ 参数。
    总计:$4H^2$ 参数。
  3. 前馈网络层
    • 上投影 $W^U, W^G \in \mathbb{R}^{H \times H’}$ → $2HH’$。
    • 下投影 $W^D \in \mathbb{R}^{H’ \times H}$ → $HH’$。
    总计:$3HH’$ 参数。
  4. 归一化层
    • 每层 2 个 RMSNorm → $2H \times L$。
    • 末层额外归一化 → $H$。
  5. 输出层:线性变换 $W^L \in \mathbb{R}^{H \times V}$ → $VH$ 参数。

总参数量公式
$$
\text{参数量} = 2VH + H + L \cdot (4H^2 + 3HH’ + 2H)
$$

示例(LLaMA-7B)
$V=32000, L=32, H=4096, H’=11008$ → 计算得 $6,738,415,616$(与实际一致)。


训练运算量估计

训练运算量以浮点操作数(FLOP)衡量,浮点运算包括浮点数的加减乘除、指数、对数、三角函数等运算。Transformer 架构训练运算量主要在多头注意力和线性变换计算。归一化、输出映射和旋转位置编码计算所需的运算量较少。设定以下参数:模型总参数量为𝑃,批处理大小为𝐵,输入序列长度为𝑇,训练词元总数为𝐶 = 𝐵𝑇;多头注意力机制包含𝑁 个头,每个头的维度为𝐷,𝐻 = 𝑁𝐷。

矩阵乘法 FLOP:$A \in \mathbb{R}^{n \times m} \times B \in \mathbb{R}^{m \times p}$ 需 $2mnp$ FLOP。

多头注意力运算量

  • 前向传播:
    • $Q,K,V \in \mathbb{R}^{B \times T \times H} $, 拆分和转置后 $Q’,K’,V’ \in \mathbb{R}^{B \times N \times T \times D} $,$Q’K’^\top$ 计算:$2BT^2ND$
    • Softmax(含标准化$\sqrt D$放缩、指数、加和、归一化):$4BT^2N$
    • 结果与 $V’$ 乘:$2BT^2ND$
    总计:$4BT^2ND + 4BT^2N$。
  • 反向传播约为前向的 2 倍 → 总 FLOP
    $$
    12 \cdot (BT^2ND + BT^2N) \cdot L = 12CTL \cdot (H + N)
    $$

线性变换运算量

  • 前向传播 FLOP = $2BTHH’$。
  • 反向传播需 2 倍前向 FLOP → 总 FLOP
    $$
    \text{运算量} = 6C \cdot (\text{线性变换参数量})
    $$
  • 激活重计算时(额外进行一次前向传播):
    $$
    \text{运算量} = 8C \cdot (\text{线性变换参数量})
    $$

简化估算(线性变换参数量占比总参数量 >95%):
$$
\text{总运算量} \approx 6CP \quad \text{或} \quad 8CP
$$

示例:LLaMA-7B($P=6.74 \times 10^9, C=10^9$)→ $6 \times 6.74 \times 10^9 \times 10^9 = 4.04 \times 10^{19}$ FLOP。


训练时间估计

训练时间的计算主要包括浮点数运算、数据读写以及多进程同步等。其中浮点数运算的耗时是训练过程中最主要的部分。实际每秒浮点运算数 FLOPS 通常为理论值的 30%-70%。
$$
\text{训练时间} = \frac{\text{运算量}}{\text{GPU 数量} \times \text{实际FLOPS}}
$$

示例(LLaMA-65B)

  • 参数量 $P=6.5 \times 10^{10}$,词元数 $C=1.4 \times 10^{12}$(激活重计算 → $8CP = 7.28 \times 10^{23}$ FLOP)。
  • 使用 2048 张 A100(假设单卡实际算力 $2 \times 10^{14}$次BF16的FLOPS),论文公布训练时间为 21 天。
    $$
    \text{时间} = \frac{7.28 \times 10^{23}}{2048 \times 2 \times 10^{14}} \approx 1.78 \times 10^6 \text{ 秒} \approx 20.6 \text{ 天}。
    $$

训练显存估计

显存占用分三部分:模型参数与优化器激活值其他开销

1. 模型参数与优化器显存

现有大模型训练中通常会采用混合精度训练,模型参数和梯度通常以16位浮点数存储,而Adam或AdamW优化器需要额外存储32位浮点数的模型参数、动量参数以及动量二阶矩参数。

优化方案 单 GPU 显存占用(字节) 优化原理
无 ZeRO $16P$ 一个16位浮点数需要2字节,一个32位浮点数需要4字节,因此模型参数和梯度各需要2𝑃字节的显存,Adam优化器的模型参数、动量参数以及动量二阶矩参数则各需要4𝑃字节的显存。累和每张GPU 上需要$(2+2+4+4+4)·𝑃 = 16𝑃$字节的显存。
ZeRO-1(优化器分区) $4P + 12P/N_D$ 优化器参数平摊到每张GPU 上,模型参数和梯度在每张显卡各自保留。每张GPU 上会需要$(2 + 2) · 𝑃 + (4 + 4 + 4) · 𝑃/𝑁_𝐷$字节显存。
ZeRO-2(梯度分区) $2P + 14P/N_D$ 将模型梯度也平摊到每张GPU 上。每张GPU 上会需要使用$2𝑃 + (2 + 4 + 4 + 4) · 𝑃/𝑁_𝐷$字节的显存。
ZeRO-3(参数分区) $16P/N_D$ 将模型参数也平摊到每张GPU 上。每张GPU需要使用$16𝑃/𝑁_𝐷$ 字节显存。

($N_D$:数据并行 GPU 数;张量/流水线并行时需额外除以 $N_T \times N_P$)。

2. 激活值显存

存储前向传播的中间结果用于反向传播梯度计算

1. 多头自注意力层
输入矩阵$X$ + $QKV$ + $QK^T$ + concat
$$ 3×2BTH + 3×2BTH + 2BT²N + 2BTH = 16BTH + 2BT²N $$

2. 前馈网络层(SwiGLU 结构)
网络输入$X$ + 门控线性变换$W^GX$和$W^UX$ + 激活输出$\sigma$
$$
2BTH + 2\times 2BTH’ + 2BTH’ = 2BTH + 6BTH’
$$

3. 归一化层
注意力层前 RMSNorm + 前馈网络层前 RMSNorm:
$$
2BTH + 2BTH = 4BTH
$$

4. 输出层
最终层归一化 + 末层解码器输出 $Y_L$ + Softmax:
$$
2BTH + 2BTH + 4BTV = 4BTH + 4BTV
$$

总激活值显存公式
$$
\text{激活值} = (16BTH + 6BTH’ + 2BT^2N) \times L + 4BTH + 4BTV
$$

优化技术对激活值的影响

技术 激活值显存公式 优化原理
FlashAttention 移除公式中的 $2BT^2N$ 项 避免存储 $QK^\top$ 矩阵
流水线并行 ($N_P$) $(16BTH + 6BTH' + 2BT^2N) \times \frac{L}{N_P} + 4BTH + 4BTV $ 层拆分到不同 GPU
张量并行 ($N_T$) $\left[\left(8 + \frac{8}{N_T}\right)BTH + \frac{6BTH'}{N_T} + \frac{2BT^2N}{N_T}\right]L + 4BTH + 4BTV $ 矩阵运算拆分到 GPU
激活重计算 $(4 + 2L)BTH + 4BTV $ 前向传播时仅保存每层的输入和最后层softmax的输入,反向传播时按需重计算激活值

3. 其他显存占用

固定开销

  • 框架内核:PyTorch 加载需 0.8–1 GB
  • ZeRO实现:DeepSpeed 库占用 1–4 GB(随优化等级增加)
  • 显存碎片:分配不连续导致,约 0.5–1 GB

动态开销

  • Softmax中间结果:计算时需要32位精度 → 占用 $8BTV$ 字节

综合计算示例:LLaMA-7B 训练显存

训练配置

  • GPU:2 × A800 (80GB)
  • 并行策略:ZeRO-3 + FlashAttention + 激活重计算
  • 超参数:$L=32, H=4096, V=32000, B=8, T=2048$

显存分解计算

  1. 参数与优化器(ZeRO-3):
    $$
    \frac{16 \times 6.74 \times 10^9}{2} = 5.39 \times 10^{10} \text{ 字节} \approx 50.20 \text{ GB}
    $$

  2. 激活值(FlashAttention):
    $$
    (4 + 2 \times 32) \times 8 \times 2048 \times 4096 + 4 \times 8 \times 2048 \times 32000 = 6.66 \times 10^9 \text{ 字节} \approx 6.20 \text{ GB}
    $$

  3. Softmax中间结果
    $$
    8 \times 8 \times 2048 \times 32000 = 4.19 \times 10^9 \text{ 字节} \approx 3.91 \text{ GB}
    $$

  4. 其他开销

    • 框架 1 GB + ZeRO 2 GB + 碎片 1 GB + 预留 2 GB ≈ 6 GB

总占用/GPU:$50.20 + 6.20 + 3.91 + 6 = 66.31 \text{ GB}$

训练资源配置建议

模型规模 参数量 最低显存需求 推荐GPU配置 批次大小 $B$
13B 13×10⁹ 208 GB 3 × 80GB GPU ≤2(效率低)
4 × 80GB GPU 12
30B 30×10⁹ 480 GB 8 × 80GB GPU 按需调整
65B 65×10⁹ 1040 GB 16 × 80GB GPU 按需调整

模型参数与优化器占用主导显存需求,需至少 16倍参数量 的显存资源。并行技术和ZeRO优化可显著降低单卡负载。