機器學習是python語言的長處,而Java在web開發方面更具有優勢,如何通過java來調用python中訓練好的模型進行在線的預測呢?在java語言中去調用python構建好的模型主要有三種方法:
1.在Java語言中,通過python的解釋器執行python代碼,簡單來說就是在java中通過python解釋器對象,傳入寫好的python代碼,進行執行,這樣的方式運行的效率非常低,而且存在很多python包無法使用的情況,只適合做簡單的python代碼的運行,並不推薦使用。
2.通過PMML工具,將在sklearn中訓練好的模型生成一個pmml格式的文件,在該文件中,主要包含了模型的一些訓練好的參數,以及輸入數據的格式和名稱等信息。生成了pmml文件之后,在java中導入pmml相關的包,我們就能通過pmml相關的類讀取生成的pmml文件,使用其中的方法傳入指定的參數就能實現模型的預測,速度快,效果不錯。
3.第二種方法因為模型已經訓練好了,無法改變,不能實現在線調參的功能,我們可以通過socket服務來進行python和java之間的網絡通信,python提供socket服務,java端將模型的參數通過網絡傳給python端,python端接受到參數之后,進行模型的訓練,訓練完成之后,將得到的結果返回給Java端。
下面給是使用pmml方式調用的步驟:
1.在python端生成pmml模型文件,下面以logistic回歸為例
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.85, random_state=1) model = PMMLPipeline([('LogisticModer', LogisticRegression())]) model.fit(x_train, y_train) y_hat = model.predict(x_test) loss = y_hat == y_test accuracy = np.mean(loss) print(accuracy) sklearn2pmml(model, '.\LogisticRegression.pmml', with_repr=True)
需要加載的包
from sklearn2pmml import sklearn2pmml from sklearn2pmml.pipeline import PMMLPipeline
我們使用PMMLPipeline()的管道函數,還可以在管道中加入其它的一些預處理的操作,比如歸一化。sklearn2pmml()函數能夠將訓練好的模型生成pmml文件,下面來看生成的pmml文件是怎樣的吧:
下面,我們建一個JavaWeb工程:
1 <dependency> 2 <groupId>org.jpmml</groupId> 3 <artifactId>pmml-evaluator</artifactId> 4 <version>1.4.1</version> 5 </dependency> 6 7 8 <dependency> 9 <groupId>org.jpmml</groupId> 10 <artifactId>pmml-evaluator-extension</artifactId> 11 <version>1.4.1</version> 12 </dependency> 13 <dependency>
在maven中引入相關的依賴,我們將要用到的方法進行封裝,制作成一個工具類:
public static PMML getPMMLModel(InputStream inputStream) { PMML pmml = new PMML(); try { pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream); } catch (SAXException e1) { e1.printStackTrace(); } catch (JAXBException e2) { e2.printStackTrace(); } finally { try { inputStream.close(); } catch (IOException e) { e.printStackTrace(); } return pmml; } } public static Evaluator loadPmmlAndgetEvaluator(MachineLearnType machineLearnType) { String modefile = getJpmmlModelPath(machineLearnType); //獲取模型的pmml文件路徑 InputStream inputStream = readPmmlFile(modefile); //根據文件路徑返回輸入流 PMML pmml = getPMMLModel(inputStream); //根據輸入流返回PMML ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); //獲取 ModelEvaluatorFactory Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); // 根據 PMML 模型返回 Evaluator 對象 pmml = null; return evaluator; } public static Map<String, Object> modelPrediction(Evaluator evaluator, Map<String, Object> paramData) { if (evaluator == null || paramData == null) { System.out.println("--------------傳入對象 evaluator 或 dataMap 為空, 無法進行預測----------------"); return null; } List<InputField> inputFields = evaluator.getInputFields(); //獲取模型的輸入域 Map<FieldName, FieldValue> arguments = new LinkedHashMap<>(); for (InputField inputField : inputFields) { //將參數通過模型對應的名稱進行添加 FieldName inputFieldName = inputField.getName(); //獲取模型中的參數名 Object paramValue = paramData.get(inputFieldName.getValue()); //獲取模型參數名對應的參數值 FieldValue fieldValue = inputField.prepare(paramValue); //將參數值填入模型中的參數中 arguments.put(inputFieldName, fieldValue); //存放在map列表中 } Map<FieldName, ?> results = evaluator.evaluate(arguments); List<TargetField> targetFields = evaluator.getTargetFields(); Map<String, Object> resultMap = new HashMap<>(); for(TargetField targetField : targetFields) { FieldName targetFieldName = targetField.getName(); Object targetFieldValue = results.get(targetFieldName); if (targetFieldValue instanceof Computable) { Computable computable = (Computable) targetFieldValue; resultMap.put(targetFieldName.getValue(), computable.getResult()); }else { resultMap.put(targetFieldName.getValue(), targetFieldValue); } } return resultMap; }
上述的方法中,我們將生成的pmml文件讀取,得到InputStream對象,調用上述的方法就行了。上面的代碼中,MachineLearnType的作用就是獲取pmml的路徑,我們將要輸入的參數放入Map中,進行預測,最后返回預測結果的Map,下面來看Service層的代碼,其中MachineLearnType.LOGISTIC_REGRESSION就是根據名稱獲取pmml文件:
Evaluator evaluator = JPmmlModelUtil.loadPmmlAndgetEvaluator(MachineLearnType.LOGISTIC_REGRESSION); Map<String , Object> results = JPmmlModelUtil.modelPrediction(evaluator, paramMap); int result =(int)((double)results.get("y"));
下面是Controller層的代碼:
/** * 使用pmml方式對輸入的參數進行線性回歸預測 */ @PostMapping("/logispmml") public ServerResponse<String> IrisLogosPmmlPredict(@RequestParam @Valid double x1, @RequestParam @Valid double x2, @RequestParam @Valid double x3, @RequestParam @Valid double x4) { logger.info("x1: " + x1 + " x2: " + x2 + " x3:" + x3 + "x4:" + x4); Map<String, Object> paramMap = new HashMap<>(); paramMap.put("x1", x1); paramMap.put("x2", x2); paramMap.put("x3", x3); paramMap.put("x4", x4); String result = logisticRegressionService.pridictlogisticpmml(paramMap); return createBySuccess(result); }
我們生成的模型是logistic回歸進行鳶尾花數據集的分類,輸入的是樣本的四個特征,輸出是類別0,1,2
int result =(int)((double)results.get("y")); String irisName = new String(); if(result == 0){ irisName = "Iris-setosa"; } if(result == 1){ irisName = "Iris-versicolor"; } if(result == 2){ irisName = "Iris-virginica"; } return irisName; }
我們在service中將預測結果轉換為對應的類別,下面使用測試工具進行測試:
我們就可以在python中將模型構建好,來進行調用啦!