博客
关于我
快手二面拷打:训练100B模型要多少显存?
阅读量:438 次
发布时间:2019-03-06

本文共 1597 字,大约阅读时间需要 5 分钟。

数据whale干货:大模型显存优化与计算方法

作者:kaiyuan,来源:知乎

摘要

本文详细分析了大模型训练/推理过程中显存的使用情况,探讨了模型参数存储、优化器状态、激活值、梯度值等关键部分对显存消耗的影响,并提出了多种优化方法和计算公式。通过这些方法,可以有效降低大模型训练的显存需求。


模型显存内容分析

在大模型的训练或推理过程中,显存的分配主要分为两部分:一部分用于AI框架的运行,另一部分用于系统的底层驱动。从用户侧可以通过显存可视化工具(如PyTorch的显存分析功能)观察训练过程中显存的使用情况。

显存消耗组成

  • 可估算值:模型参数、优化器状态、激活值、梯度值、输出数据。
  • 未命名数据:临时变量、未知数据。
  • 其他:自动梯度数据。
  • 显存消耗差异

    由于未知数据的影响,显存估算值与实际测量值可能存在较大差异(误差可超过30%)。例如,估算值为50GB,而实际测试值可能达到75GB。


    计算公式与优化方法

    训练场景

    训练过程中,显存消耗主要由以下四部分组成:

  • 模型参数(Model Memory)

    模型参数的存储大小与参数量和数据类型有关。例如,1B模型采用fp32存储,约需要3.725GB,LLama13b模型存储空间约为52GB。

  • 优化器状态(Optimizer status)

    Adam优化器的状态包括模型副本、Momentum参数和Variance参数。对于8位优化器,存储空间占用减少为4→1→1Bytes。

  • 梯度值(Gradient)

    梯度值的存储大小与模型数据类型一致,例如fp32梯度占用为4Bytes。

  • 激活值(Activation)

    激活值的存储大小与模型结构和并行策略有关。Megtron论文提供了计算公式:
    [ \text{激活值显存} = s \times b \times h \times a \times t \times \lambda ]
    其中,s为序列长度,b为微批量大小,h为隐藏层大小,a为注意力头数,t为Tensor并行数值,L为模型层数,λ为精度系数。

  • 并行计算与显存优化

    为了降低单卡显存消耗,常用的并行策略包括TensorParallel(TP)、SequenceParallel(SP)、PipelineParallel(PP)和Zero方法。

  • 3D并行

    通过将模型、激活值和梯度分割到不同GPU,显存占用可降低至原来的1/3。公式为:
    [ \text{优化后的显存} = \frac{\text{单卡显存}}{N} \times (1 - \frac{\text{模型参数}}{\text{总参数}}) ]

  • 重计算(Recomputation)

    通过丢弃中间数据,显存占用可降低至原来的1/2。公式为:
    [ \text{激活值显存} = s \times b \times h \times a \times \lambda ]

  • Zero方法

    Zero1、Zero2和Zero3通过减少GPU上的参数存储,显存占用可降低至原来的1/8。公式为:
    [ \text{Zero方法显存} = N \times \frac{\text{参数总量}}{L} ]


  • 推理场景

    推理的显存估算公式较为简单,主要包含以下部分: [ \text{总显存} = \text{模型参数} + \text{激活值} + \text{输出数据} ]


    显存优化总结

    显存优化可以通过以下方法实现:

  • 多卡并行:降低单卡显存压力,适合大规模模型训练。
  • 算子优化:选择精度与显存占用平衡的算子。
  • 数据类型优化:采用低精度(如fp16、int8)存储,降低显存占用。
  • 框架优化:消除框架副本,减少内存碎片。
  • 底层API优化:通过优化CUDA函数和显存管理,降低显存消耗。
  • 通过合理搭配以上方法,可以有效降低大模型训练的显存需求,为模型规模的提升提供保障。


    点赞三连

    转载地址:http://olwyz.baihongyu.com/

    你可能感兴趣的文章
    NIFI1.21.0_Mysql到Mysql增量CDC同步中_日期类型_以及null数据同步处理补充---大数据之Nifi工作笔记0057
    查看>>
    NIFI1.21.0_Mysql到Mysql增量CDC同步中_补充_更新时如果目标表中不存在记录就改为插入数据_Postgresql_Hbase也适用---大数据之Nifi工作笔记0059
    查看>>
    NIFI1.21.0_NIFI和hadoop蹦了_200G集群磁盘又满了_Jps看不到进程了_Unable to write in /tmp. Aborting----大数据之Nifi工作笔记0052
    查看>>
    NIFI1.21.0_Postgresql和Mysql同时指定库_指定多表_全量同步到Mysql数据库以及Hbase数据库中---大数据之Nifi工作笔记0060
    查看>>
    NIFI1.21.0最新版本安装_连接phoenix_单机版_Https登录_什么都没改换了最新版本的NIFI可以连接了_气人_实现插入数据到Hbase_实际操作---大数据之Nifi工作笔记0050
    查看>>
    NIFI1.21.0最新版本安装_配置使用HTTP登录_默认是用HTTPS登录的_Https登录需要输入用户名密码_HTTP不需要---大数据之Nifi工作笔记0051
    查看>>
    NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表多表增量同步_增删改数据分发及删除数据实时同步_通过分页解决变更记录过大问题_02----大数据之Nifi工作笔记0054
    查看>>
    NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表多表增量同步_增加修改实时同步_使用JsonPath及自定义Python脚本_03---大数据之Nifi工作笔记0055
    查看>>
    NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表多表增量同步_插入修改删除增量数据实时同步_通过分页解决变更记录过大问题_01----大数据之Nifi工作笔记0053
    查看>>
    NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表或全表增量同步_实现指定整库同步_或指定数据表同步配置_04---大数据之Nifi工作笔记0056
    查看>>
    NIFI1.23.2_最新版_性能优化通用_技巧积累_使用NIFI表达式过滤表_随时更新---大数据之Nifi工作笔记0063
    查看>>
    NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_根据binlog实现数据实时delete同步_实际操作04---大数据之Nifi工作笔记0043
    查看>>
    NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_配置binlog_使用处理器抓取binlog数据_实际操作01---大数据之Nifi工作笔记0040
    查看>>
    NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_配置数据路由_实现数据插入数据到目标数据库_实际操作03---大数据之Nifi工作笔记0042
    查看>>
    NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_配置数据路由_生成插入Sql语句_实际操作02---大数据之Nifi工作笔记0041
    查看>>
    NIFI从MySql中离线读取数据再导入到MySql中_03_来吧用NIFI实现_数据分页获取功能---大数据之Nifi工作笔记0038
    查看>>
    NIFI从MySql中离线读取数据再导入到MySql中_不带分页处理_01_QueryDatabaseTable获取数据_原0036---大数据之Nifi工作笔记0064
    查看>>
    NIFI从MySql中离线读取数据再导入到MySql中_无分页功能_02_转换数据_分割数据_提取JSON数据_替换拼接SQL_添加分页---大数据之Nifi工作笔记0037
    查看>>
    NIFI从PostGresql中离线读取数据再导入到MySql中_带有数据分页获取功能_不带分页不能用_NIFI资料太少了---大数据之Nifi工作笔记0039
    查看>>
    nifi使用过程-常见问题-以及入门总结---大数据之Nifi工作笔记0012
    查看>>