在早先的一篇科学博客中,曾提出MiCS方法,该方法显著提升了参数规模达1750亿的机器学习模型的训练效率。然而,为了在新任务中实现可靠的少样本学习,自然语言处理模型仍需向万亿参数规模扩展。本文介绍对MiCS的两项新增强,使某中心的客户能够训练和微调万亿参数规模的模型:(1)连续参数管理;(2)预取激活卸载。
上图展示了在四设备集群上,采用ZeRO-3数据并行策略的双层深度学习神经网络模型在前向和反向传播中的参数收集过程。在前向计算开始前,每个工作节点(rank)仅持有部分模型参数。为计算第一层的激活值,需使用全收集(all-gather)操作聚合其参数。
获得第一层输出后,立即对参数进行分区以释放内存,并继续处理下一神经网络层。在计算梯度时,这两个步骤以相反顺序重复执行。反复的全收集和分区过程导致集体通信频繁使用,这在Pytorch中引发严重的内存碎片化和缓存刷新问题。为解决此问题,预分配连续参数缓冲区以容纳收集后的完整参数张量,并自主管理张量生命周期和碎片整理,而不影响Pytorch内存分配器的行为。实践表明,该方法显著提升了内存受限任务的性能。
此外,还开发了预取激活卸载技术以进一步节省GPU内存,该技术与激活检查点结合使用。每个检查点激活被卸载至CPU内存,并在反向传播期间通过CUDA并行计算平台开启的专用流进行预取。由于数据传输是异步的,使用预取激活卸载仅导致约1-2%的速度损失。
某中心近期发布了由400 Gbps网络和80GB GPU内存驱动的P4de GPU实例预览。P4de提供的GPU内存是P4d的两倍,使得更少节点即可在GPU内存中容纳大型模型,从而降低通信开销。新硬件使我们能够更高效地扩展MiCS以支持更大模型。
实验结果表明,在公有云的64个P4de实例上训练210层1.06万亿参数模型时,实现了每GPU 176万亿次浮点运算的最高性能(达到理论峰值的56.4%)。该设置中,模型隐藏层大小为20,480,词汇表大小为50,264。使用序列长度1,024,每GPU批大小为8,前向和反向传播采用bfloat16精度,优化器为float32 Adam。
MiCS预览版现已在某中心Pytorch DLC和SageMaker中以分片数据并行形式提供。通过利用新技术,某中心客户可用仅需本地DGX-A100集群四分之一网络带宽突破GPU内存限制,实现万亿参数模型训练。
致谢:Yida Wang, RJ