解密萬億參數M6模型預訓練背后的分布式框架Whale


簡介: 最近,阿里雲PAI團隊和達摩院智能計算實驗室一起發布“低碳版”巨模型M6,大幅降低萬億參數超大模型訓練能耗。借助我們自研的Whale框架僅使用480卡GPU,即訓練出了規模達人類神經元10倍的萬億參數多模態大模型M6,與傳統海外公司實現萬億參數規模相比,能耗降低超八成、效率提升近11倍。

image.png

作者 | 王林

來源 | 阿里技術公眾號

最近,阿里雲PAI團隊和達摩院智能計算實驗室一起發布“低碳版”巨模型M6,大幅降低萬億參數超大模型訓練能耗。借助我們自研的Whale框架僅使用480卡GPU,即訓練出了規模達人類神經元10倍的萬億參數多模態大模型M6,與傳統海外公司實現萬億參數規模相比,能耗降低超八成、效率提升近11倍。

M6是國內首個實現商業化落地的多模態大模型。M6擁有超越傳統AI的認知和創造能力,擅長繪畫、寫作、問答,在電商、制造業、文學藝術等諸多領域擁有廣泛應用前景。

這里來為大家介紹支持萬億參數模型訓練的Whale框架設計。

一 模型發展趨勢和挑戰

1 模型發展趨勢

隨着深度學習的火爆,模型的參數規模也增長迅速,OpenAI數據顯示:

  • 2012年以前,模型計算耗時每2年增長一倍,和摩爾定律保持一致;
  • 2012年后,模型計算耗時每3.4個月翻一倍,遠超硬件發展速度;

image.png

近一年模型參數規模飛速增長,谷歌、英偉達、阿里、智源研究院都發布了萬億參數模型,有大廠也發布了百億、千億參數模型。同時,隨着模型參數規模增大,模型效果也在逐步提高,Nvidia測試Bert模型不同參數規模,發現模型困惑度隨模型參數規模增加而降低。

image.png

Google在GShard paper中也發現MoETransformer 模型參數規模越大,翻譯質量越高。

image.png

2 大模型訓練的挑戰

大模型帶來模型效果提升的同時,也為訓練框架帶來更大的挑戰,例如當我們要訓練一個萬億規模的模型時會面臨如下挑戰:

  • 訓練難:

    • GPU顯存已經不夠存放模型副本,數據並行已經不能滿足需求;
    • 需要框架提供新的並行策略,協同多GPU能力來存放和訓練模型;
    • 如何給用戶提供簡潔、易用的接口,讓用戶能很容易實現分布式版模型;
    • 超大規模模型對計算效率、通信效率都帶來很大挑戰,如何提高計算和通信效率;
    • 下游任務如何對接,如何支持批量預測和在線推理需求;
  • 成本高:

    • 以萬億模型為例,模型參數有4TB大小、梯度也有4TB,加上optimizer states和active tensor,顯存需求巨大;
    • 業界訓練同等規模模型需要的資源:英偉達 3072 A100、谷歌 2048 TPU v3,成本太高很難落地;
    • 如何降本增效,使用更少的資源,更快的訓練收斂;

當前已經有一些分布式訓練框架,例如:Horovod、Tensorflow Estimator、PyTorch DDP等支持數據並行,Gpipe、PipeDream、PipeMare等支持流水並行,Mesh Tensorflow、FlexFlow、OneFlow、MindSpore等支持算子拆分,但這些框架還有一些不足:

  • 模式單一:很多框架只支持部分並行策略,不能完全支持各種混合並行;
  • 接入門檻高:用戶實現模型分布式版本難度大、成本高,需要有領域專家經驗才能實現高效的分布式並行策略;
  • 遷移代價大:不同分布式框架並行化實現割裂,不同框架有各自定義的DSL,當用戶要切換並行策略時,需要學習各種接口,重新改寫模型;
  • 性能不理想:部分框架實現未考慮集群物理環境;

為了應對當前分布式訓練的挑戰,我們研發了分布式訓練框架Whale,主要目標是:

  • 統一多種並行策略:在一個框架中支持各種並行策略以及這些策略的各種組合;
  • 簡潔易用的接口:用戶只需添加幾行annotation即可完成並行策略的配置,模型代碼不需要改動;
  • 高效的訓練框架:結合硬件資源、網絡拓撲和模型進行協同優化,打造高效分布式訓練框架;

二 PAI自研Whale框架

1 Whale架構

我們推出統一多種並行策略的高性能分布式訓練框架Whale,從如下角度來應對分布式訓練的挑戰:

  • 將不同並行化策略進行統一抽象、封裝,在一套分布式訓練框架中支持多種並行策略;
  • 基於Tensorflow設計一套分布式並行接口,完全兼容Tensorflow,用戶僅僅只需添加幾行annotation就可以實現豐富的分布式並行策略;
  • 結合模型結構和網絡拓撲進行調度和通信優化,提供高效的分布式訓練能力。

Whale框架如下圖所示,主要分4個模塊:

  • API:提供簡潔易用接口,讓用戶組合使用各種混合並行策略;
  • Whale IR:將並行策略轉成內部表達,通過TaskGraph、Multi-Dimension、VirtualDevices抽象來表達各種並行策略;
  • Whale Engine:基於WhaleIR,通過圖編輯工具來構建分布式執行圖;
  • Runtime:將分布式執行圖轉成TFGraph,再調用TF 的Runtime來執行;

image.png

2 Whale簡介易用接口

Whale提供簡潔易用的接口來描述各種並行策略,主要的原語:

  • cluster:配置Virtual Device的划分方法
  • replica:數據並行
  • stage:划分TaskGraph
  • pipeline:流水並行
  • split:算子拆分

用這些接口可以組合各種並行策略,例如:

  • 數據並行:

image.png

  • 流水並行:

image.png

  • 流水並行+數據並行:

image.png

  • 更多並行策略示例:

image.png

3 Whale訓練流程

使用Whale進行分布式訓練流程:

  • 並行策略配置:

    • 使用Whale API來為模型配置並行策略,只需添加幾行annotation,無需修改模型代碼,方法如 2.2節 所示;
    • 可以將模型划分為多個TaskGraph,TaskGraph支持配置多個並行策略,每個TaskGraph可以配置不同的並行策略;
  • 虛擬資源划分:

    • 按並行策略來划分Virtual Device,每個TaskGraph對應一個Virtual Device;
    • 按GPU資源和網絡topo來為Virtual Device選擇Physical Device;
  • 分布式執行圖:

    • 基於並行策略和資源分配信息,使用圖編輯工具來編輯執行圖(圖拷貝、拆分、插入通信節點等),生成最終的分布式執行圖;
    • 調用TF的runtime來執行分布式Graph;

image.png

三 萬億M6模型預訓練

萬億模型的算力需求非常大,為了降低算力需求,Whale中實現了MoE(Mixture-of-Experts)結構,MoE的主要特點是稀疏激活,使用Gating(Router)來為輸入選擇Top k的expert進行計算(k常用取值1、2),從而大大減少算力需求。

image.png

Whale中實現了MoE(Mixture-of-Experts) layer,並支持專家並行,將experts拆分到多個Devices上,降低單個Device的顯存和算力需求。同時數據並行有利於提升訓練的並發度,因此采用數據並行+專家並行組合的混合並行策略來訓練M6模型:MoElayer采用專家並行,其他layer采用數據並行。

image.png

Whale中提供簡潔易用的接口來進行模型的混合並行訓練,只需要增加幾行annotation來配置並行策略,模型本身不需要任何修改。M6模型采用數據並行+專家並行的策略,只需要增加如下圖的annotation:

image.png

同時為了節約訓練資源,提高訓練效率,Whale中提供各種優化技術:

顯存優化:

  • Auto Gradient Checkpoint,自動選擇最優checkpoint節點,節約activation的顯存;
  • Group-wise Apply,優化Optimizer Apply階段的顯存;
  • CPU Offload技術,優化Optimizer status和Weight的顯存;
  • 通信池化,控制通信的數據塊大小和並發,節約通信的顯存;

計算、通信加速:

  • 采用DP+EP混合並行策略,降低算力需求;
  • 采用分組融合通信、半精度通信、拓撲感知的All2All通信算子等技術來提高通信效率;
  • 結合混合精度、編譯優化等技術提高訓練效率;

借助Whale框架,首次在480 V100 上,3天內完成萬億M6模型的預訓練。相比此前英偉達使用3072 A100 GPU實現萬億參數、谷歌使用2048 TPU實現1.6萬億參數大模型,此次達摩院僅使用480卡V100 32G GPU就實現了萬億模型M6,節省算力資源超80%,且訓練效率提升近11倍。

四 結語

模型參數規模已越來越大,大模型已成為發展趨勢,為解決超大模型訓練的挑戰,我們自研Whale框架,將不同並行化策略進行統一抽象、封裝,在一套分布式訓練框架中支持多種並行策略。Whale提供簡潔易用的接口,用戶只需添加幾行annotation即可實現各種並行策略,不需要對模型本身進行修改。同時我們結合硬件資源、網絡topo、模型進行軟硬件協同優化,提供高效分布式訓練框架。

通過Whale框架,我們用480 V100 GPU卡訓練萬億規模模型,並在3天內完成模型訓練收斂,為超大規模模型訓練落地提供了可能,后續我們會進一步完善Whale框架,從更大規模、更快速度、更高性價比3個維度去擴展Whale框架的能力。同時也會推動Whale能力在更多業務場景落地,讓技術能力到產品能力的轉變。

原文鏈接
本文為阿里雲原創內容,未經允許不得轉載。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM