大模型显存占用由以下几部分组成:

1. 模型本身参数,假设是1个单位

2.模型的梯度,同样也是一个单位

3.优化器参数(占大头):以Adam参数为例,还需要在显卡中额外存储mv两个参数,因此为2个单位参数

4.模型的中间计算结果,因为反向传播求导时会用到,需要存储每一层的输入x(下图以Transformer中的全连接层为例,每一个全连接层的输入参数维度为[batch, 句子长度, 每个token维度]

 以11B大小模型为例,其模型参数占据显存大小就为40GB,再加上其余三个部分后显存花销更大

 

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐