本文共 1597 字,大约阅读时间需要 5 分钟。
数据whale干货:大模型显存优化与计算方法
作者:kaiyuan,来源:知乎
摘要
本文详细分析了大模型训练/推理过程中显存的使用情况,探讨了模型参数存储、优化器状态、激活值、梯度值等关键部分对显存消耗的影响,并提出了多种优化方法和计算公式。通过这些方法,可以有效降低大模型训练的显存需求。在大模型的训练或推理过程中,显存的分配主要分为两部分:一部分用于AI框架的运行,另一部分用于系统的底层驱动。从用户侧可以通过显存可视化工具(如PyTorch的显存分析功能)观察训练过程中显存的使用情况。
由于未知数据的影响,显存估算值与实际测量值可能存在较大差异(误差可超过30%)。例如,估算值为50GB,而实际测试值可能达到75GB。
训练过程中,显存消耗主要由以下四部分组成:
模型参数(Model Memory)
模型参数的存储大小与参数量和数据类型有关。例如,1B模型采用fp32存储,约需要3.725GB,LLama13b模型存储空间约为52GB。优化器状态(Optimizer status)
Adam优化器的状态包括模型副本、Momentum参数和Variance参数。对于8位优化器,存储空间占用减少为4→1→1Bytes。梯度值(Gradient)
梯度值的存储大小与模型数据类型一致,例如fp32梯度占用为4Bytes。激活值(Activation)
激活值的存储大小与模型结构和并行策略有关。Megtron论文提供了计算公式: [ \text{激活值显存} = s \times b \times h \times a \times t \times \lambda ] 其中,s为序列长度,b为微批量大小,h为隐藏层大小,a为注意力头数,t为Tensor并行数值,L为模型层数,λ为精度系数。为了降低单卡显存消耗,常用的并行策略包括TensorParallel(TP)、SequenceParallel(SP)、PipelineParallel(PP)和Zero方法。
3D并行
通过将模型、激活值和梯度分割到不同GPU,显存占用可降低至原来的1/3。公式为: [ \text{优化后的显存} = \frac{\text{单卡显存}}{N} \times (1 - \frac{\text{模型参数}}{\text{总参数}}) ]重计算(Recomputation)
通过丢弃中间数据,显存占用可降低至原来的1/2。公式为: [ \text{激活值显存} = s \times b \times h \times a \times \lambda ]Zero方法
Zero1、Zero2和Zero3通过减少GPU上的参数存储,显存占用可降低至原来的1/8。公式为: [ \text{Zero方法显存} = N \times \frac{\text{参数总量}}{L} ]推理的显存估算公式较为简单,主要包含以下部分: [ \text{总显存} = \text{模型参数} + \text{激活值} + \text{输出数据} ]
显存优化可以通过以下方法实现:
通过合理搭配以上方法,可以有效降低大模型训练的显存需求,为模型规模的提升提供保障。
点赞三连↓
转载地址:http://olwyz.baihongyu.com/