Spark快速獲得CrossValidator的最佳模型參數


Spark提供了便利的Pipeline模型,可以輕松的創建自己的學習模型。

但是大部分模型都是需要提供參數的,如果不提供就是默認參數,那么怎么選擇參數就是一個比較常見的問題。Spark提供在org.apache.spark.ml.tuning包下提供了模型選擇器,可以替換參數然后比較模型輸出。

目前有CrossValidator和TrainValidationSplit兩種,比如一個文本情感預測模型。

Pipeline只有三步,第一步切詞,第二步HashingTF,第三步NB分類

Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[]{tokenizer, hashingTF, naiveBayes}); ParamMap[] paramMaps = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10000, 100000, 500000, 1000000}) .build(); CrossValidator cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator()) .setEstimatorParamMaps(paramMaps);

其中HashingTF的參數選擇非常重要,我們這里就隨便嘗試幾種,然后放在CrossValidator中去。

最后我們會獲得一個CrossValidatorModel類,這里有兩種選擇。

第一種是自己手動獲取其中的參數,因為bestModel的參數就是我們最后選擇的參數

Pipeline bestPipeline = (Pipeline) model.bestModel().parent(); PipelineStage stage = bestPipeline.getStages()[1]; stage.extractParamMap().get(stage.getParam("numFeatures"));

這種方法可以獲得值,但是需要根據你模型情況修改獲取的位置。

如果你只是想知道最佳參數是多少,並不是需要在上下文中使用,那還有一個更簡單的方法。

修改log4j的配置,添加

log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO

效果如下:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM