使用java調用python訓練出的pmml模型


記錄下自己的過程,以后可以隨時用,如果能幫到大家就更好了。

從安裝軟件說起,嫌麻煩的就別看了。

一、下載工具(俗話說得好,預先善其事必先利其器!哈哈)

我剛開始安裝的是eclipse,但有諸多麻煩不能解決,就用了IDEA,和Pycharm一個公司發行的。

首先進入官網: http://www.jetbrains.com/products.html#lang=java

選擇IDEA下載:

 

由於社區版的功能太少,我下載的是企業版的,后邊會告訴破解方法。

IDEA的安裝教程網上都有,正常安裝就好。

企業版的激活碼大家可以關注一個公眾號,我也是在網上找到的。

http://idea.medeming.com/

關注公眾號后粘貼就行了。

二、Java環境安裝

參考教程:https://blog.csdn.net/weixin_38381149/article/details/89668578

寫博客時想找當時看的博客,但發現了這個很全的,jdk,maven,tomcat都有。

想當初我為了裝一個maven花了好久。。。

三、新建Maven項目

  File ==》New==》Project==》Maven

四、接下來在IDEA中配置Maven,這是當時參考的博客:https://www.cnblogs.com/jiangzhaowei/p/9534393.html

五、添加依賴

  由於我只是為了調用模型,沒有太多依賴,只添加了這么幾個

    <dependencies>

        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.4.1</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>1.4.1</version>
        </dependency>

        <dependency>
            <groupId>javax.xml.bind</groupId>
            <artifactId>jaxb-api</artifactId>
            <version>2.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-core</artifactId>
            <version>2.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-impl</artifactId>
            <version>2.3.0</version>
        </dependency>

    </dependencies>

六、java調用Python訓練出的pmml模型的代碼

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ClassificationModel {
    private Evaluator modelEvaluator;

    /**
     * 通過傳入 PMML 文件路徑來生成機器學習模型
     *
     * @param pmmlFileName pmml 文件路徑
     */
    public ClassificationModel(String pmmlFileName) {
        PMML pmml = null;

        try {
            if (pmmlFileName != null) {
                InputStream is = new FileInputStream(pmmlFileName);
                pmml = PMMLUtil.unmarshal(is);
                try {
                    is.close();
                } catch (IOException e) {
                    System.out.println("InputStream close error!");
                }

                ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();

                this.modelEvaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
                modelEvaluator.verify();
                System.out.println("加載模型成功!");
            }
        } catch (SAXException e) {
            e.printStackTrace();
        } catch (JAXBException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }

    }

    // 獲取模型需要的特征名稱
    public List<String> getFeatureNames() {
        List<String> featureNames = new ArrayList<String>();

        List<InputField> inputFields = modelEvaluator.getInputFields();

        for (InputField inputField : inputFields) {
            featureNames.add(inputField.getName().toString());
        }
        return featureNames;
    }

    // 獲取目標字段名稱
    public String getTargetName() {
        return modelEvaluator.getTargetFields().get(0).getName().toString();
    }

    // 使用模型生成概率分布
    private ProbabilityDistribution getProbabilityDistribution(Map<FieldName, ?> arguments) {
        Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments);

        FieldName fieldName = new FieldName(getTargetName());

        return (ProbabilityDistribution) evaluateResult.get(fieldName);

    }

    // 預測不同分類的概率
    public ValueMap<String, Number> predictProba(Map<FieldName, Number> arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
        return probabilityDistribution.getValues();
    }

    // 預測結果分類
    public Object predict(Map<FieldName, ?> arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);

        return probabilityDistribution.getPrediction();
    }

    public static void main(String[] args) {
        ClassificationModel clf = new ClassificationModel("D:/JupyterSpace/RandomForestClassifier_Iris.pmml"); //這里模型地址

        List<String> featureNames = clf.getFeatureNames();
        System.out.println("feature: " + featureNames);

        // 構建待預測數據
        Map<FieldName, Number> waitPreSample = new HashMap<>();
     #這里的key一定要對應python中的列名 waitPreSample.put(
new FieldName("sepal length (cm)"), 10); waitPreSample.put(new FieldName("sepal width (cm)"), 1); waitPreSample.put(new FieldName("petal length (cm)"), 3); waitPreSample.put(new FieldName("petal width (cm)"), 2); System.out.println("waitPreSample predict result: " + clf.predict(waitPreSample).toString()); System.out.println("waitPreSample predictProba result: " + clf.predictProba(waitPreSample).toString()); } }

注意事項:

1、類名和文件名要一致

2、打開File  ==》Project Structure

看你的JDK版本和這里是否一致

運行程序,查看是否報錯。

這是我報的一個錯:

NoClassDefFoundError: javax/activation/DataSource

  解決方法是下載:activation.jar包。

  下載地址:

    鏈接:https://pan.baidu.com/s/14D8cQWIJp2d7h2iljAPZ2A
    提取碼:6f37

應該沒什么問題了。有問題請留言,一定回復。(有問題一定要告訴我,以后還要用呢。。。)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM