Spark提供了便利的Pipeline模型,可以輕松的創建自己的學習模型。
但是大部分模型都是需要提供參數的,如果不提供就是默認參數,那么怎么選擇參數就是一個比較常見的問題。Spark提供在org.apache.spark.ml.tuning包下提供了模型選擇器,可以替換參數然后比較模型輸出。
目前有CrossValidator和TrainValidationSplit兩種,比如一個文本情感預測模型。
Pipeline只有三步,第一步切詞,第二步HashingTF,第三步NB分類
Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[]{tokenizer, hashingTF, naiveBayes}); ParamMap[] paramMaps = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10000, 100000, 500000, 1000000}) .build(); CrossValidator cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator()) .setEstimatorParamMaps(paramMaps);
其中HashingTF的參數選擇非常重要,我們這里就隨便嘗試幾種,然后放在CrossValidator中去。
最后我們會獲得一個CrossValidatorModel類,這里有兩種選擇。
第一種是自己手動獲取其中的參數,因為bestModel的參數就是我們最后選擇的參數
Pipeline bestPipeline = (Pipeline) model.bestModel().parent(); PipelineStage stage = bestPipeline.getStages()[1]; stage.extractParamMap().get(stage.getParam("numFeatures"));
這種方法可以獲得值,但是需要根據你模型情況修改獲取的位置。
如果你只是想知道最佳參數是多少,並不是需要在上下文中使用,那還有一個更簡單的方法。
修改log4j的配置,添加
log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO
效果如下: