大模型显存占用分析笔记:模型显存,训练显存,推理显存,显存优化
模型显存,训练显存,推理显存,显存优化
·
【如果笔记对你有帮助,欢迎关注&点赞&收藏,收到正反馈会加快更新!谢谢支持!】
一、模型权重显存占用
- 常用的模型保存格式:fp32、fp16、bf16、int8【数字x表示:x位保存一个模型参数】
- int8: 模型显存 = 1*参数量(Byte)
- fp16, bf16: 模型显存 = 2*参数量(Byte)
- fp32: 模型显存 = 4*参数量(Byte)
- 显存计算(以fp16的7B模型举例):
- 精确计算:显存占用(GB) = 2 * 7 * 10^9 / (1024^3) = 13.038GB
- 估算:2 * 7 = 14GB
二、训练显存占用
- 训练显存消耗 = 模型权重 + 优化器状态 + 梯度 + 激活状态 (+ 临时数据&未知)
- 优化器状态:
- 优化器显存占用 = 主权重(模型副本) + 动量 + 方差 = 3倍模型权重
- 以混合精度(fp16/bf16 + fp32)为例:
- 模型权重 = 2 * 参数量(Byte)【FP/BF16】
- 优化器显存占用 = FP32 主权重 +
参数量×8字节(动量4字节 + 方差4字节)【FP32】
- 梯度:和模型权重一致
- 激活值:和模型参数、重计算、并行策略等相关
- 混合精度可以将激活值显存减少 60%~70%
三、推理阶段显存占用
- 总显存 = 模型权重 + 前向计算开销
- 估算:总显存 = 1.2 * 模型权重
四、显存优化
4.1 核心显存优化技术
- 混合精度训练:前向/反向传播使用FP16/BF16,优化器状态保留FP32主副本
- 梯度累积:累积多个小批次梯度后更新权重,显存降低为累积步数N的倒数
- 梯度检查点(Gradient Checkpointing):仅存储部分层激活值,反向传播时重算中间层,以时间换空间
- 量化技术:如模型参数/激活值从FP16转为INT8
4.2 分布式训练优化
- 数据并行(Data Parallelism): PyTorch DDP
- 张量并行(Tensor Parallelism):单层参数横向拆分(如Megatron-LM)
- 流水线并行(Pipeline Parallelism):模型层纵向拆分,重叠前后向计算(如GPipe)
4.3 推理阶段优化
- KV Cache压缩
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)