一文帶你從零認識什么是XLA


摘要:簡要介紹XLA的工作原理以及它在 Pytorch下的使用。

本文分享自華為雲社區《XLA優化原理簡介》,作者: 拓荒者01。

初識XLA

XLA的全稱是Accelerated Linear Algebra,即加速線性代數。作為一種深度學習編譯器,長期以來被作為Tensorflow框架的一個試驗特性被開發,歷時至今已經超過兩三年了,隨着Tensorflow 2.X的發布,XLA也終於從試驗特性變成了默認打開的特性。此外, Pytorch社區也在大力推動XLA在Pytorch下的開發,現在已經有推出PyTorch/XLA TPU版本,暫只支持谷歌平台TPU上使用。

LLVM

提到編譯器就不得不提大名鼎鼎的LLVM。LLVM是一個編譯器框架,由C++語言編寫而成,包括一系列分模塊、可重用的編譯工具。

LLVM框架的主要組成部分有:

1、前端:負責將源代碼轉換為一種中間表示

2、優化器:負責優化中間代碼

3、后端:生成可執行機器碼的模塊

LLVM為不同的語言提供了同一種中間表示LLVM IR,這樣子如果我們需要開發一種新的語言的時候,我們只需要實現對應的前端模塊,如果我們想要支持一種新的硬件,我們只需要實現對應的后端模塊,其他部分可以復用。

XLA編譯

XLA也是基於LLVM框架開發的,前端的輸入是Graph,前端沒有將Graph直接轉化為LLVM IR。首先XLA的功能主要體現在兩個方面:

1、即時編譯(Just-in-time)

2、超前編譯(Aheda-of-time)

無論是哪個功能,都是服務於以下目的:

1、提高代碼執行速度

2、優化存儲使用

此外,XLA還有着大部分深度學習編譯器都有的夢想:擺脫計算庫的限制,自動生成算子代碼並支持在多硬件上的良好可移植性。

作為編譯器,XLA負責對前端定義的計算圖進行優化。如上圖所示,XLA的優化流程可以分成兩方面,目標無關優化和目標相關優化。在優化步驟之間傳遞的是計算圖的中間表示形式,HLO,即High Level Optimizer(高級優化器) ,XLA用這種中間表示形式表示正在被優化的計算圖,其有自己的文法和語義,這里不做詳細介紹

XLA優勢

  • 編譯子計算圖以減少短暫運算的執行時間,從而消除運行時的開銷;融合流水線運算以降低內存開銷;並針對已知張量形狀執行專門優化以支持更積極的常量傳播。
  • 提高內存使用率: 分析和安排內存使用,消除了許多中間存儲緩沖區。
  • 降低對自定義運算的依賴:通過提高自動融合的低級運算的性能,使之達到手動融合的自定義運算的性能水平,從而消除對多種自定義運算的需求。
  • 提高便攜性:使針對新穎硬件編寫新后端的工作變得相對容易,在新硬件上運行時,大部分程序都能夠以未經修改的方式運行。與針對新硬件專門設計各個整體運算的方式相比,這種模式不必重新編寫 程序即可有效利用這些運算。

XLA工作原理

我們先來看XLA如何作用於計算圖,下面是一張簡單的計算圖

這里我們假設XLA僅支持matmul和add。XLA通過圖優化方法,在計算圖中找到適合被JIT編譯的區域

XLA把這個區域定義為一個Cluster,作為一個獨立的JIT編譯單元,計算圖中通過Node Attribute標示

然后另一個的圖優化方法,把cluster轉化成TensorFlow的一個Function子圖。在原圖上用一個Caller節點表示這個Function在原圖的位置

最后調用TensorFlow的圖優化方法(BuildXlaOps),把Function節點轉化成特殊的Xla節點。

在TensorFlow運行時,運行到XlaCompile時,編譯Xla cluster子圖,然后把編譯完的Executable可執行文件通過XlaExecutableClosure傳給XlaRun運行。

接着根據虛擬指令分配GPU Stream和顯存,然后IrEmitter把HLO Graph轉化成由編譯器的中間表達LLVM IR表示的GPU Kernel。最后由LLVM生成nvPTX(Nvidia定義的虛擬底層指令表達形式)表達,進而由NVCC生成CuBin可執行代碼。

AOT和JIT

JIT,動態(即時)編譯,邊運行邊編譯;AOT,指運行前編譯。這兩種編譯方式的主要區別在於是否在“運行時”進行編譯,對於AI訓練模型中,AOT模式下更具有性能優勢,具體流程如下圖:

對於大部分AI模型來說,訓練過程一般情況下圖是不會怎么變的,所以在訓這樣子就在執行過程中省略練的時候使用AOT模式能大大提高訓練的速度

Pytorch/XLA

創建 XLA 張量:PyTorch/XLA 為 PyTorch 添加了新的 xla 設備類型。 此設備類型的工作方式與普通 PyTorch 設備類型一樣。 例如,以下是創建和打印 XLA 張量的方法:

這段代碼應該看起來很熟悉。 PyTorch/XLA 使用與常規 PyTorch 相同的界面,但添加了一些內容。 導入 torch_xla 初始化 PyTorch/XLA,xm.xla_device() 返回當前的 XLA 設備。 這可能是 CPU 或 GPU,具體取決於您的環境。

XLA 張量是 PyTorch 張量:PyTorch 操作可以在 XLA 張量上執行,就像 CPU 或 CUDA 張量一樣。例如,XLA 張量可以相加:

XLA 設備上運行模型:構建新的 PyTorch 網絡或將現有網絡轉換為在 XLA 設備上運行只需要幾行特定於 XLA 的代碼,現階段官方只支持JIT模式。 以圖是在官方版本單個XLA設備上運行時代碼段

這段代碼可以看出切換model在 XLA 上運行是多么容易。 model定義、數據加載器、優化器和訓練循環可以在任何設備上工作。 唯一的 特別代碼是獲取 XLA device和mark step的幾行代碼。因為XLA tensor運行是lazy( 懶惰的)。 所以只在圖形中記錄操作,直到需要結果為止,調用 xm.mark_step() 才會執行其當前圖獲取運行結果並更新模型的參數。

 

點擊關注,第一時間了解華為雲新鮮技術~


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM