記錄下自己的過程,以后可以隨時用,如果能幫到大家就更好了。
從安裝軟件說起,嫌麻煩的就別看了。
一、下載工具(俗話說得好,預先善其事必先利其器!哈哈)
我剛開始安裝的是eclipse,但有諸多麻煩不能解決,就用了IDEA,和Pycharm一個公司發行的。
首先進入官網: http://www.jetbrains.com/products.html#lang=java
選擇IDEA下載:
由於社區版的功能太少,我下載的是企業版的,后邊會告訴破解方法。
IDEA的安裝教程網上都有,正常安裝就好。
企業版的激活碼大家可以關注一個公眾號,我也是在網上找到的。
http://idea.medeming.com/
關注公眾號后粘貼就行了。
二、Java環境安裝
參考教程:https://blog.csdn.net/weixin_38381149/article/details/89668578
寫博客時想找當時看的博客,但發現了這個很全的,jdk,maven,tomcat都有。
想當初我為了裝一個maven花了好久。。。
三、新建Maven項目
File ==》New==》Project==》Maven
四、接下來在IDEA中配置Maven,這是當時參考的博客:https://www.cnblogs.com/jiangzhaowei/p/9534393.html
五、添加依賴
由於我只是為了調用模型,沒有太多依賴,只添加了這么幾個
<dependencies> <dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator</artifactId> <version>1.4.1</version> </dependency> <dependency> <groupId>org.jpmml</groupId> <artifactId>pmml-evaluator-extension</artifactId> <version>1.4.1</version> </dependency> <dependency> <groupId>javax.xml.bind</groupId> <artifactId>jaxb-api</artifactId> <version>2.3.0</version> </dependency> <dependency> <groupId>com.sun.xml.bind</groupId> <artifactId>jaxb-core</artifactId> <version>2.3.0</version> </dependency> <dependency> <groupId>com.sun.xml.bind</groupId> <artifactId>jaxb-impl</artifactId> <version>2.3.0</version> </dependency> </dependencies>
六、java調用Python訓練出的pmml模型的代碼
import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.*; import org.jpmml.model.PMMLUtil; import org.xml.sax.SAXException; import javax.xml.bind.JAXBException; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; public class ClassificationModel { private Evaluator modelEvaluator; /** * 通過傳入 PMML 文件路徑來生成機器學習模型 * * @param pmmlFileName pmml 文件路徑 */ public ClassificationModel(String pmmlFileName) { PMML pmml = null; try { if (pmmlFileName != null) { InputStream is = new FileInputStream(pmmlFileName); pmml = PMMLUtil.unmarshal(is); try { is.close(); } catch (IOException e) { System.out.println("InputStream close error!"); } ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); this.modelEvaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml); modelEvaluator.verify(); System.out.println("加載模型成功!"); } } catch (SAXException e) { e.printStackTrace(); } catch (JAXBException e) { e.printStackTrace(); } catch (FileNotFoundException e) { e.printStackTrace(); } } // 獲取模型需要的特征名稱 public List<String> getFeatureNames() { List<String> featureNames = new ArrayList<String>(); List<InputField> inputFields = modelEvaluator.getInputFields(); for (InputField inputField : inputFields) { featureNames.add(inputField.getName().toString()); } return featureNames; } // 獲取目標字段名稱 public String getTargetName() { return modelEvaluator.getTargetFields().get(0).getName().toString(); } // 使用模型生成概率分布 private ProbabilityDistribution getProbabilityDistribution(Map<FieldName, ?> arguments) { Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments); FieldName fieldName = new FieldName(getTargetName()); return (ProbabilityDistribution) evaluateResult.get(fieldName); } // 預測不同分類的概率 public ValueMap<String, Number> predictProba(Map<FieldName, Number> arguments) { ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments); return probabilityDistribution.getValues(); } // 預測結果分類 public Object predict(Map<FieldName, ?> arguments) { ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments); return probabilityDistribution.getPrediction(); } public static void main(String[] args) { ClassificationModel clf = new ClassificationModel("D:/JupyterSpace/RandomForestClassifier_Iris.pmml"); //這里模型地址 List<String> featureNames = clf.getFeatureNames(); System.out.println("feature: " + featureNames); // 構建待預測數據 Map<FieldName, Number> waitPreSample = new HashMap<>();
#這里的key一定要對應python中的列名 waitPreSample.put(new FieldName("sepal length (cm)"), 10); waitPreSample.put(new FieldName("sepal width (cm)"), 1); waitPreSample.put(new FieldName("petal length (cm)"), 3); waitPreSample.put(new FieldName("petal width (cm)"), 2); System.out.println("waitPreSample predict result: " + clf.predict(waitPreSample).toString()); System.out.println("waitPreSample predictProba result: " + clf.predictProba(waitPreSample).toString()); } }
注意事項:
1、類名和文件名要一致
2、打開File ==》Project Structure
看你的JDK版本和這里是否一致
運行程序,查看是否報錯。
這是我報的一個錯:
NoClassDefFoundError: javax/activation/DataSource
解決方法是下載:activation.jar包。
下載地址:
鏈接:https://pan.baidu.com/s/14D8cQWIJp2d7h2iljAPZ2A
提取碼:6f37
應該沒什么問題了。有問題請留言,一定回復。(有問題一定要告訴我,以后還要用呢。。。)