PMML簡介
預測模型標記語言PMML(Predictive Model Markup Language)是一套與平台和環境無關的模型表示語言,是目前表示機器學習模型的實際標准。
作為一個開放的成熟標准,PMML由數據挖掘組織DMG(Data Mining Group)開發和維護,經過十幾年的發展,得到了廣泛的應用,有超過30家廠商和開源項目(包括SAS,IBM SPSS,KNIME,RapidMiner等主流廠商)在它們的數據挖掘分析產品中支持並應用PMML,這些廠商應用詳情見下表:PMML Powered
PMML標准介紹
PMML是一套基於XML的標准,通過 XML Schema 定義了使用的元素和屬性,主要由以下核心部分組成:
數據字典(Data Dictionary),描述輸入數據。
數據轉換(Transformation Dictionary和Local Transformations),應用在輸入數據字段上生成新的派生字段。
模型定義 (Model),每種模型類型有自己的定義。
輸出(Output),指定模型輸出結果。
PMML預測過程符合數據挖掘分析流程:
pmml-flow.png
PMML優點
平台無關性。PMML可以讓模型部署環境脫離開發環境,實現跨平台部署,是PMML區別於其他模型部署方法最大的優點。比如使用Python建立的模型,導出PMML后可以部署在Java生產環境中。
互操作性。這就是標准協議的最大優勢,實現了兼容PMML的預測程序可以讀取其他應用導出的標准PMML模型。
廣泛支持性。已取得30余家廠商和開源項目的支持,通過已有的多個開源庫,很多重量級流行的開源數據挖掘模型都可以轉換成PMML。
可讀性。PMML模型是一個基於XML的文本文件,使用任意的文本編輯器就可以打開並查看文件內容,比二進制序列化文件更安全可靠。
PMML開源類庫
模型轉換庫,生成PMML:
Python模型:
Nyoka,支持Scikit-Learn,LightGBM,XGBoost,Statsmodels和Keras。https://github.com/nyoka-pmml/nyoka
JPMML系列,比如JPMML-SkLearn、JPMML-XGBoost、JPMML-LightGBM等,提供命令行程序導出模型到PMML。https://github.com/jpmml
R模型:
R pmml包:https://cran.r-project.org/web/packages/pmml/index.html
r2pmml:https://github.com/jpmml/r2pmml
JPMML-R:提供命令行程序導出R模型到PMML,https://github.com/jpmml/jpmml-r
Spark:
Spark mllib,但是只是模型本身,不支持Pipelines,不推薦使用。
JPMML-SparkML,支持Spark ML pipleines。https://github.com/jpmml/jpmml-sparkml
模型評估庫,讀取PMML:
Java:
JPMML-Evaluator,純Java的PMML預測庫,開源協議是AGPL V3。https://github.com/jpmml/jpmml-evaluator
PMML4S,使用Scala開發,同時提供Scala和Java API,接口簡單好用,開源協議是常用的寬松協議Apache 2。https://github.com/autodeployai/pmml4s
Python:
PyPMML,PMML的Python預測庫,PyPMML是PMML4S包裝的Python接口。https://github.com/autodeployai/pypmml
Spark:
JPMML-Evaluator-Spark,https://github.com/jpmml/jpmml-evaluator-spark
PMML4S-Spark,https://github.com/autodeployai/pmml4s-spark
PySpark:
PyPMML-Spark,PySpark中預測PMML模型。https://github.com/autodeployai/pypmml-spark
REST API:
AI-Serving,同時為PMML模型提供REST和gRPC API,開源協議Apache 2。https://github.com/autodeployai/ai-serving
Openscoring,提供REST API,開源協議AGPL V3。https://github.com/openscoring/openscoring
sparkml訓練完模型后保存模型為PMML文件:
model.toPMML(spark.sparkContext, "G:\pmml\spark\lr\xml")
java 使用pmml4s加載pmml文件示例:
pmml文件:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_2">
<Header description="logistic regression">
<Application name="Apache Spark MLlib" version="2.4.0"/>
<Timestamp>2020-11-25T19:52:36</Timestamp>
</Header>
<DataDictionary numberOfFields="5">
<DataField name="field_0" optype="continuous" dataType="double"/>
<DataField name="field_1" optype="continuous" dataType="double"/>
<DataField name="field_2" optype="continuous" dataType="double"/>
<DataField name="field_3" optype="continuous" dataType="double"/>
<DataField name="target" optype="categorical" dataType="string"/>
</DataDictionary>
<RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit">
<MiningSchema>
<MiningField name="field_0" usageType="active"/>
<MiningField name="field_1" usageType="active"/>
<MiningField name="field_2" usageType="active"/>
<MiningField name="field_3" usageType="active"/>
<MiningField name="target" usageType="target"/>
</MiningSchema>
<RegressionTable intercept="0.0" targetCategory="1">
<NumericPredictor name="field_0" coefficient="-5.95759188680503"/>
<NumericPredictor name="field_1" coefficient="-1.6974588567364868"/>
<NumericPredictor name="field_2" coefficient="-5.660350922982105"/>
<NumericPredictor name="field_3" coefficient="8.680992926976252"/>
</RegressionTable>
<RegressionTable intercept="-0.0" targetCategory="0"/>
</RegressionModel>
</PMML>
pom文件添加依賴:
<dependency>
<groupId>org.pmml4s</groupId>
<artifactId>pmml4s_2.12</artifactId>
<version>0.9.7</version>
</dependency>
java代碼:
import org.pmml4s.model.Model;
import java.util.HashMap;
import java.util.Map;
public class PMML4SDemo {
public static void main(String[] args) {
Model model = Model.fromFile("G:\\pmml\\spark\\lr\\xml\\lr.xml");
Map<String, Object> result = model.predict(new HashMap<String, Object>() {{
put("field_0", 2);
put("field_1", 4);
put("field_2", 1);
put("field_3", 5);
}});
System.out.println(result);
System.out.println(result.get("predicted_target"));
}
}