機器學習——Java調用sklearn生成好的Logistic模型進行鳶尾花的預測


  機器學習是python語言的長處,而Java在web開發方面更具有優勢,如何通過java來調用python中訓練好的模型進行在線的預測呢?在java語言中去調用python構建好的模型主要有三種方法:

  1.在Java語言中,通過python的解釋器執行python代碼,簡單來說就是在java中通過python解釋器對象,傳入寫好的python代碼,進行執行,這樣的方式運行的效率非常低,而且存在很多python包無法使用的情況,只適合做簡單的python代碼的運行,並不推薦使用。

  2.通過PMML工具,將在sklearn中訓練好的模型生成一個pmml格式的文件,在該文件中,主要包含了模型的一些訓練好的參數,以及輸入數據的格式和名稱等信息。生成了pmml文件之后,在java中導入pmml相關的包,我們就能通過pmml相關的類讀取生成的pmml文件,使用其中的方法傳入指定的參數就能實現模型的預測,速度快,效果不錯。

  3.第二種方法因為模型已經訓練好了,無法改變,不能實現在線調參的功能,我們可以通過socket服務來進行python和java之間的網絡通信,python提供socket服務,java端將模型的參數通過網絡傳給python端,python端接受到參數之后,進行模型的訓練,訓練完成之后,將得到的結果返回給Java端。

  下面給是使用pmml方式調用的步驟:

  1.在python端生成pmml模型文件,下面以logistic回歸為例

    x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.85, random_state=1)
    model = PMMLPipeline([('LogisticModer', LogisticRegression())])
    model.fit(x_train, y_train)
    y_hat = model.predict(x_test)
    loss = y_hat == y_test
    accuracy = np.mean(loss)
    print(accuracy)
    sklearn2pmml(model, '.\LogisticRegression.pmml', with_repr=True)

  需要加載的包

from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

  我們使用PMMLPipeline()的管道函數,還可以在管道中加入其它的一些預處理的操作,比如歸一化。sklearn2pmml()函數能夠將訓練好的模型生成pmml文件,下面來看生成的pmml文件是怎樣的吧:

  下面,我們建一個JavaWeb工程:

 1         <dependency>
 2             <groupId>org.jpmml</groupId>
 3             <artifactId>pmml-evaluator</artifactId>
 4             <version>1.4.1</version>
 5         </dependency>
 6 
 7 
 8         <dependency>
 9             <groupId>org.jpmml</groupId>
10             <artifactId>pmml-evaluator-extension</artifactId>
11             <version>1.4.1</version>
12         </dependency>
13         <dependency>

  在maven中引入相關的依賴,我們將要用到的方法進行封裝,制作成一個工具類:

public static PMML getPMMLModel(InputStream inputStream) {
        PMML pmml = new PMML();
        try {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (SAXException e1) {
            e1.printStackTrace();
        } catch (JAXBException e2) {
            e2.printStackTrace();
        } finally {
            try {
                inputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            return pmml;
        }
    }


    public static Evaluator loadPmmlAndgetEvaluator(MachineLearnType machineLearnType) {

            String modefile = getJpmmlModelPath(machineLearnType);  //獲取模型的pmml文件路徑

            InputStream inputStream = readPmmlFile(modefile);  //根據文件路徑返回輸入流

            PMML pmml = getPMMLModel(inputStream);  //根據輸入流返回PMML

            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();  //獲取 ModelEvaluatorFactory

            Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);  // 根據 PMML 模型返回 Evaluator 對象

            pmml = null;

            return evaluator;
    }

    public static Map<String, Object> modelPrediction(Evaluator evaluator, Map<String, Object> paramData) {
        if (evaluator == null || paramData == null) {
            System.out.println("--------------傳入對象 evaluator 或 dataMap 為空, 無法進行預測----------------");
            return null;
        }

        List<InputField> inputFields = evaluator.getInputFields();   //獲取模型的輸入域
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();

        for (InputField inputField : inputFields) {            //將參數通過模型對應的名稱進行添加
            FieldName inputFieldName = inputField.getName();   //獲取模型中的參數名
            Object paramValue = paramData.get(inputFieldName.getValue());   //獲取模型參數名對應的參數值
            FieldValue fieldValue = inputField.prepare(paramValue);   //將參數值填入模型中的參數中
            arguments.put(inputFieldName, fieldValue);          //存放在map列表中
        }
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        List<TargetField> targetFields = evaluator.getTargetFields();

        Map<String, Object> resultMap = new HashMap<>();

        for(TargetField targetField : targetFields) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            if (targetFieldValue instanceof Computable) {
                Computable computable = (Computable) targetFieldValue;
                resultMap.put(targetFieldName.getValue(), computable.getResult());
            }else {
                resultMap.put(targetFieldName.getValue(), targetFieldValue);
            }
        }
        return resultMap;
    }

  上述的方法中,我們將生成的pmml文件讀取,得到InputStream對象,調用上述的方法就行了。上面的代碼中,MachineLearnType的作用就是獲取pmml的路徑,我們將要輸入的參數放入Map中,進行預測,最后返回預測結果的Map,下面來看Service層的代碼,其中MachineLearnType.LOGISTIC_REGRESSION就是根據名稱獲取pmml文件:

Evaluator evaluator = JPmmlModelUtil.loadPmmlAndgetEvaluator(MachineLearnType.LOGISTIC_REGRESSION);
Map<String , Object> results = JPmmlModelUtil.modelPrediction(evaluator, paramMap);
int result =(int)((double)results.get("y"));

  下面是Controller層的代碼:

  /**
     * 使用pmml方式對輸入的參數進行線性回歸預測
     */
    @PostMapping("/logispmml")
    public ServerResponse<String> IrisLogosPmmlPredict(@RequestParam @Valid double x1,
                                                         @RequestParam @Valid double x2,
                                                         @RequestParam @Valid double x3,
                                                         @RequestParam @Valid double x4) {
        logger.info("x1: " + x1 + " x2: " + x2 + " x3:" + x3 + "x4:" + x4);
        Map<String, Object> paramMap = new HashMap<>();
        paramMap.put("x1", x1);
        paramMap.put("x2", x2);
        paramMap.put("x3", x3);
        paramMap.put("x4", x4);
        String result = logisticRegressionService.pridictlogisticpmml(paramMap);
        return createBySuccess(result);
    }

  我們生成的模型是logistic回歸進行鳶尾花數據集的分類,輸入的是樣本的四個特征,輸出是類別0,1,2

int result =(int)((double)results.get("y"));
String irisName = new String();
if(result == 0){
    irisName = "Iris-setosa";
}
if(result == 1){
    irisName = "Iris-versicolor";
}
if(result == 2){
    irisName = "Iris-virginica";
}
    return irisName;
}

  我們在service中將預測結果轉換為對應的類別,下面使用測試工具進行測試:

  我們就可以在python中將模型構建好,來進行調用啦!

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM