在早前的一篇科学博客中,我们提出了MiCS方法,该方法显著提升了高达1750亿参数机器学习模型的训练效率。但随着自然语言处理模型向万亿参数规模扩展的需求日益增长,本文介绍了两项对MiCS的新增强功能,使某中心客户能够训练和微调万亿参数规模的模型:(1)连续参数管理;(2)预取激活卸载。
上图展示了在四设备集群上使用ZeRO-3数据并行进行前向和后向传播的过程。在开始前向传播前,每个工作节点仅持有部分模型参数。为计算第一层的激活值,我们使用全收集操作来收集其参数。
获得第一层输出后,我们立即对其参数进行分区以释放内存,并继续处理下一神经网络层。在计算梯度时,这两个步骤会以相反顺序重复执行。重复的全收集和分区过程导致集体通信频繁使用,这会在PyTorch中引发严重的内存碎片化和缓存刷新问题。为解决此问题,我们预分配连续参数缓冲区来保存收集后的完整参数张量,并自行管理张量生命周期和碎片整理,而不影响PyTorch内存分配器的行为。实践表明,该方法显著提升了内存受限任务的性能。
此外,我们还开发了预取激活卸载技术,通过与激活检查点结合使用进一步节省GPU内存。每个检查点激活会被卸载至CPU内存,并在反向传播期间通过CUDA并行计算平台开启的专用流进行预取。由于数据传输是异步的,使用预取激活卸载仅导致约1-2%的速度损失。
某中心近期发布了搭载400 Gbps网络和80GB GPU内存的EC2 P4de GPU实例预览版。P4de提供的GPU内存是P4d的两倍,使得更少节点即可在GPU内存中容纳大型模型,从而降低通信开销。新硬件使我们能更高效地利用MiCS扩展至更大模型。
实验结果显示,在公有云的64个P4de实例上训练210层1.06万亿参数模型时,我们实现了每GPU 176万亿次浮点运算的性能(达到理论峰值的56.4%)。该模型隐藏层大小为20,480,词表大小为50,264,使用序列长度1,024,每GPU批大小8,前向后向传播采用bfloat16精度,并搭配float32 Adam优化器。
MiCS预览版现已通过某中心PyTorch DLC和SageMaker中的分片数据并行功能提供。通过利用这些新技术,某中心客户可突破GPU内存限制,仅需本地DGX-A100集群四分之一网络带宽即可实现万亿参数模型训练。