MLIR與Code Generation
MLIR多級中間表示
MLIR 項目是一種構建可重用和可擴展編譯器基礎架構的新方法。MLIR 旨在解決軟件碎片問題,改進異構硬件的編譯,顯着降低構建特定領域編譯器的成本,幫助將現有編譯器連接在一起。
MLIR作用
MLIR 旨在成為一種混合 IR,可以在統一的基礎架構中,支持多種不同的需求。例如,包括:
• 表示數據流圖的能力(例如在 TensorFlow 中),包括動態shape,user-extensible用戶可擴展的操作生態系統,TensorFlow 變量等。
• 優化和轉換通常在此類圖上完成(例如在 Grappler 中)。
• 能夠跨內核(融合,循環交換,平鋪等)托管高性能計算風格的循環優化,轉換數據的內存布局。
• 代碼生成“降低”轉換,例如 DMA 插入,顯式緩存管理,內存平鋪,一維和二維寄存器架構的矢量化。
• 能夠表示特定於目標的操作,例如特定於加速器的高級操作。
• 在深度學習圖上完成的量化和其它圖轉換。
• Polyhedral primitives。
• Hardware Synthesis Tools / HLS。
MLIR 是一種常見的 IR,也支持硬件特定的操作。因此,對圍繞 MLIR 的基礎架構的任何投資(例如,對其進行工作的編譯器通過),都應該產生良好的回報;許多目標可以使用基礎架構受益。
MLIR 是一個強大的表示,但也有非目標。不嘗試支持低級機器代碼生成算法(如寄存器分配和指令調度)。更適合較低級別的優化器(例如 LLVM)。此外,不希望 MLIR 成為最終用戶編寫內核的源語言(類似於 CUDA C++)。另一方面,MLIR 提供了代表任何此類 DSL集成到生態系統中的主干。
編譯器基礎架構
在構建 MLIR 時,從構建其他 IR(LLVM IR,XLA HLO 和 Swift SIL)中獲得的經驗中受益。MLIR 框架鼓勵現有的最佳實踐,例如編寫和維護 IR 規范,構建 IR 驗證器,提供將 MLIR 文件轉儲和解析為文本的能力,使用FileCheck 工具編寫大量單元測試,將基礎架構構建為一組可以以新方式組合的模塊化庫。
LLVM 有一些不明顯的設計錯誤,阻止多線程編譯器,處理 LLVM 模塊中的多個函數。MLIR 通過限制 SSA 范圍,減少 use-def 鏈,用顯式替換cross-function引用,解決這些問題 symbol reference。
代碼生成(Code Generation)
代碼生成(Code Generation)技術廣泛應用於現代的數據系統中。代碼生成是將用戶輸入的表達式,查詢,存儲過程等現場編譯成二進制代碼再執行,相比解釋執行的方式,運行效率要高得多。尤其是對於計算密集型查詢,或頻繁重復使用的計算過程,運用代碼生成技術能達到數十倍的性能提升。
很多大數據產品都將代碼生成技術作為賣點,然而事實上,往往談論的不是一件事情。比如,之前就有人提問:Spark 1.x 就已經有代碼生成技術,為什么 Spark 2.0 又把代碼生成吹了一番?其中的原因在於,雖然都是代碼生成,但是各個產品生成代碼的粒度是不同的:
o 最簡單的,例如 Spark 1.4,使用代碼生成技術加速表達式計算;
o Spark 2.0 支持將同一個 Stage 的多個算子組合編譯成一段二進制;
o 更有甚者,支持將自定義函數,存儲過程等編譯成一段二進制,例如 SQL Server。
本文主要講上面最簡單的表達式編譯。通過一個簡單的例子,初步了解代碼生成的流程。
解析執行的缺陷
在講代碼生成之前,回顧一下解釋執行。以上面圖中的表達式 X×5+log(10)X×5+log(10) 為例,計算過程是一個深度優先搜索(DFS)的過程:
1. 調用根節點 + 的 visit() 函數:分別調用左,右子節點的 visit() 再相加;
2. 調用乘法節點 * 的 visit() 函數:分別調用左,右子節點的 visit() 再相乘;
3. 調用變量節點 X 的 visit() 函數:從環境中讀取 XX 的值以及類型。
(……略)最終,DFS 回到根節點,得到最終結果。
@Override public Object visitPlus(CalculatorParser.PlusContext ctx) {
Object left = visit(ctx.plusOrMinus());
Object right = visit(ctx.multOrDiv());
if (left instanceof Long && right instanceof Long) {
return (Long) left + (Long) right;
} else if (left instanceof Long && right instanceof Double) {
return (Long) left + (Double) right;
} else if (left instanceof Double && right instanceof Long) {
return (Double) left + (Long) right;
} else if (left instanceof Double && right instanceof Double) {
return (Double) left + (Double) right;
}
throw new IllegalArgumentException();
}
上述過程中,有幾個顯而易見的性能問題:
o 涉及到大量的虛函數調用,即函數綁定的過程,例如 visit() 函數,虛函數調用是一個非確定性的跳轉指令, CPU 無法做預測分支,從而導致打斷 CPU 流水線;
o 在計算之前不能確定類型,因而各個算子的實現中會出現很多動態類型判斷,例如:如果 + 左邊是 DECIMAL 類型,而右邊是 DOUBLE,需要先把左邊轉換成 DOUBLE 再相加;
o 遞歸中的函數調用打斷了計算過程,不僅調用本身需要額外的指令,而且函數調用傳參是通過棧完成的,不能很好的利用寄存器(這一點在現代的編譯器和硬件體系中,已經有所緩解,但顯然比不上連續的計算指令)。
代碼生成基本過程
代碼生成執行,顧名思義,最核心的部分是生成出需要的執行代碼。
拜編譯器所賜,並不需要寫難懂的匯編或字節碼。在 native 程序中,通常用 LLVM 的中間語言(IR)作為生成代碼的語言。而 JVM 上更簡單,因為 Java 編譯本身很快,利用運行在 JVM 上的輕量級編譯器 janino,可以直接生成 Java 代碼。
無論是 LLVM IR 還是 Java 都是靜態類型的語言,在生成的代碼中再去判斷類型顯然不是個明智的選擇。通常的做法是在編譯之前就確定所有值的類型。幸運的是,表達式和 SQL 執行計划都可以事先做類型推導。
所以,綜上所述,代碼生成往往是個 2-pass 的過程:先做類型推導,再做真正的代碼生成。第一步中,類型推導的同時其實也是在檢查表達式是否合法,因此很多地方也稱之為驗證(Validate)。
在代碼生成完成后,調用編譯器編譯,得到了所需的函數(類),調用它即可得到計算結果。如果函數包含參數,例如上面例子中的 X,每次計算可以傳入不同的參數,編譯一次,計算多次。
以下的代碼實現都可以在 GitHub 項目 fuyufjh/calculator 找到。
驗證(Validate)
為了盡可能簡單,例子中僅涉及兩種類型:Long 和 Double
這一步中,將合法的表達式 AST 轉換成 Algebra Node,這是一個遞歸語法樹的過程,下面是一個例子(由於 Plus 接收 Long/Double 的任意類型組合,所以此處沒有做類型檢查):
@Override public AlgebraNode visitPlus(CalculatorParser.PlusContext ctx) {
return new PlusNode(visit(ctx.plusOrMinus()), visit(ctx.multOrDiv()));
}
AlgebraNode 接口定義如下:
public interface AlgebraNode {
DataType getType(); // Validate 和 CodeGen 都會用到
String generateCode(); // CodeGen 使用
List<AlgebraNode> getInputs();
}
實現類大致與 AST 的中的節點相對應,如下圖。
對於加法,類型推導的過程很簡單——如果兩個操作數都是 Long 則結果為 Long,否則為 Double。
@Override public DataType getType() {
if (dataType == null) {
dataType = inferTypeFromInputs();
}
return dataType;
}
private DataType inferTypeFromInputs() {
for (AlgebraNode input : getInputs()) {
if (input.getType() == DataType.DOUBLE) {
return DataType.DOUBLE;
}
}
return DataType.LONG;
}
生成代碼
依舊以加法為例,利用上面實現的 getType(),可以確定輸入,輸出的類型,生成出強類型的代碼:
@Override public String generateCode() {
if (getLeft().getType() == DataType.DOUBLE && getRight().getType() == DataType.DOUBLE) {
return "(" + getLeft().generateCode() + " + " + getRight().generateCode() + ")";
} else if (getLeft().getType() == DataType.DOUBLE && getRight().getType() == DataType.LONG) {
return "(" + getLeft().generateCode() + " + (double)" + getRight().generateCode() + ")";
} else if (getLeft().getType() == DataType.LONG && getRight().getType() == DataType.DOUBLE) {
return "((double)" + getLeft().generateCode() + " + " + getRight().generateCode() + ")";
} else if (getLeft().getType() == DataType.LONG && getRight().getType() == DataType.LONG) {
return "(" + getLeft().generateCode() + " + " + getRight().generateCode() + ")";
}
throw new IllegalStateException();
}
注意,目前代碼還是以 String 形式存在的,遞歸調用的過程中通過字符串拼接,一步步拼成完整的表達式函數。
以表達式 a + 2*3 - 2/x + log(x+1) 為例,最終生成的代碼如下:
1 (((double)(a + (2 * 3)) - ((double)2 / x)) + java.lang.Math.log((x + (double)1)))
其中,a,x 都是未知數,但類型是已經確定的,分別是 Long 型和 Double 型。
編譯器編譯
Janino 是一個流行的輕量級 Java 編譯器,與常用的 javac 相比它最大的優勢是:可以在 JVM 上直接調用,直接在進程內存中運行編譯,速度很快。
上述代碼僅僅是一個表達式,並不是完整的 Java 代碼,但 janino 提供了方便的 API 能直接編譯表達式:
ExpressionEvaluator evaluator = new ExpressionEvaluator();
evaluator.setParameters(parameterNames, parameterTypes); // 輸入參數名及類型
evaluator.setExpressionType(rootNode.getType() == DataType.DOUBLE ? double.class : long.class); // 輸出類型
evaluator.cook(code); // 編譯代碼
實際上,你也可以手工拼接出如下的類代碼,交給 janino 編譯,效果是完全相同的:
class MyGeneratedClass {
public double calculate(long a, double x) {
return (((double)(a + (2 * 3)) - ((double)2 / x)) + java.lang.Math.log((x + (double)1)));
}
}
最后,依次輸入所有參數即可調用剛剛編譯的函數:
Object result = evaluator.evaluate(parameterValues);
References
o Apache Spark - GitHub
o Janino by janino-compiler
o fuyufjh/calculator: A simple calculator to demonstrate code gen technology
參考鏈接:
https://mlir.llvm.org/
https://ericfu.me/code-gen-of-expression/