llamafactory gradient_checkpointing 梯度检查点 通俗完整讲解
1. 原生不开启时(你现在 false 的状态)
模型走一遍前向传播,会把所有中间激活值全部存在显存里。
作用:反向传播算梯度时,直接拿这些中间值计算,不用重新跑前向,速度快。
代价:序列越长、batch越大,中间激活占的显存爆炸,很容易冲到99%。
举个你场景的例子:cutoff_len=1024、bs=32,14B大模型,中间激活张量体积非常大,这就是你显存经常顶满的核心元凶。
2. 开启 gradient_checkpointing: true 做了什么?
核心逻辑:以少量重复计算,换取大幅显存节省
- 前向传播时,不保存全部中间激活,只存少量关键节点;
- 等到反向传播需要某一段中间数据时,重新再跑一遍对应区间的前向计算,现场算出激活;
- 全程不会一次性堆海量中间张量在显存里,显存峰值直接下降 30%~50%。
和你梯度累积完全无关
不管 gradient_accumulation_steps 是2还是4,梯度检查点只管「前向激活要不要存」,不影响梯度累加逻辑。
3. 优缺点
优点
- 大幅压低显存峰值,解决你 88%~99% 冲高、偶尔OOM崩溃问题;
- 同样显卡下,可以开更大 batch / 更长序列长度;
- 搭配 FlashAttention2 双重省显存,MI300 192GB 体验提升明显。
缺点
- 每一步训练会多跑一小段前向计算,训练速度大概慢 10%~20%;
- CPU算力、读写开销轻微上涨,
preprocessing_num_workers=8完全能扛住。
4. 适配你当前场景怎么选
- 现状:显存经常冲到99%,波动极大
→建议开启 true,稳定性优先,轻微降速完全可接受; - 如果你后续调小batch、且全程显存稳定80以内不报警,再切回 false 提速。
5. 补充关键细节(SFT训练专用)
use_cache: false必须搭配梯度检查点一起用(你配置里已经开了,正确),训练阶段禁用KV缓存,额外省一大块显存;- LoRA微调场景下开启梯度检查点不影响LoRA梯度更新,只冻结基础模型部分重算,LoRA训练效果无损失;
- 和 bf16 / FlashAttention2 兼容,ROCm MI300无兼容性bug。
极简总结
不开:存所有中间激活 → 显存占用高、跑的快
开启:丢掉大部分中间激活,反向时临时重算 → 显存砍半、速度略慢
专门解决你长文本1024序列导致的显存爆满问题。