1、概述
ML中的一項重要任務是模型選擇,或使用數據為給定任務找到最佳模型或參數。這也稱為tuning。
可以針對單個估算器(例如LogisticRegression)進行調整,也可以針對包括多個算法,特征化和其他步驟的整個管道進行調整。用戶可以一次調整整個管道,而不必分別調整管道中的每個元素。
MLlib使用諸如CrossValidator和TrainValidationSplit之類的工具支持模型選擇。這些工具需要以下各項:
在較高級別,這些模型選擇工具的工作方式如下:
- 他們將輸入數據分為單獨的訓練和測試數據集。
- 對於每對(訓練,測試),它們都會遍歷一組ParamMap:對於每個ParamMap,他們使用這些參數擬合Estimator,獲得擬合的Model,然后使用Evaluator評估Model的性能。
- 他們選擇由性能最佳的參數集生成的模型。
該評估器可以是用於回歸問題的RegressionEvaluator,用於二元數據的BinaryClassificationEvaluator或用於多元問題的MulticlassClassificationEvaluator。
每個評估器中的setMetricName方法都可以覆蓋用於選擇最佳ParamMap的默認度量。
2、Cross-Validation交叉驗證
CrossValidator首先將數據集分成一組折疊,這些折疊用作單獨的訓練和測試數據集。例如,k = 3
折疊后,CrossValidator將生成3個(訓練,測試)數據集對,每個對都使用2/3的數據進行訓練,並使用1/3的數據進行測試。
為了評估特定的ParamMap,CrossValidator為3個模型(通過將Estimator擬合到3個不同的(訓練,測試)數據集對上)計算出平均評估指標。
確定最佳的ParamMap之后,CrossValidator最終使用最佳的ParamMap和整個數據集重新擬合Estimator。
3、適用情況
當數據集比較小的時候
交叉驗證可以“充分利用”有限的數據找到合適的模型參數,防止過度擬合
一般做深度學習跑標准數據集的時候用不到
4、code
package com.home.spark.ml import org.apache.spark.SparkConf import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.ml.linalg.Vector /** * @Description: 交叉驗證選擇最佳模型參數 * 請注意,在參數網格上進行交叉驗證的成本很高。 * 例如,在下面的示例中,參數網格具有3個值的hashingTF.numFeatures和2個值的lr.regParam,而CrossValidator使用2次折疊。這乘以(3×2)×2 = 12 * 訓練不同的模型。在實際設置中,嘗試更多的參數並使用更多的折疊數(通常是k = 3和k = 10)是很常見的。 * 換句話說,使用CrossValidator可能非常昂貴。但是,這也是一種公認的用於選擇參數的方法,該方法在統計上比啟發式手動調整更合理。 **/ object Ex_CrossValidator { def main(args: Array[String]): Unit = { val conf = new SparkConf(true).setAppName("spark ml model selection").setMaster("local[2]") val spark = SparkSession.builder().config(conf).getOrCreate() // import spark.implicits._ // Prepare training data from a list of (id, text, label) tuples. val training = spark.createDataFrame(Seq( (0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), (3L, "hadoop mapreduce", 0.0), (4L, "b spark who", 1.0), (5L, "g d a y", 0.0), (6L, "spark fly", 1.0), (7L, "was mapreduce", 0.0), (8L, "e spark program", 1.0), (9L, "a e c l", 0.0), (10L, "spark compile", 1.0), (11L, "hadoop software", 0.0) )).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() .setInputCol("text") .setOutputCol("words") val hashingTF = new HashingTF() .setInputCol(tokenizer.getOutputCol) .setOutputCol("features") val lr = new LogisticRegression() .setMaxIter(10) val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. val paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) .addGrid(lr.regParam, Array(0.1, 0.01)) .build() // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. // This will allow us to jointly choose parameters for all Pipeline stages. // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric // is areaUnderROC. val cv = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator) .setEstimatorParamMaps(paramGrid) .setNumFolds(2) // Use 3+ in practice .setParallelism(2) // Evaluate up to 2 parameter settings in parallel // Run cross-validation, and choose the best set of parameters. val cvModel = cv.fit(training) // Prepare test documents, which are unlabeled (id, text) tuples. val test = spark.createDataFrame(Seq( (4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), (7L, "apache hadoop") )).toDF("id", "text") // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => println(s"($id, $text) --> prob=$prob, prediction=$prediction") } spark.stop() } }
result:
(4, spark i j k) --> prob=[0.25806842225846466,0.7419315777415353], prediction=1.0 (5, l m n) --> prob=[0.9185597412653913,0.08144025873460858], prediction=0.0 (6, mapreduce spark) --> prob=[0.43203205663918753,0.5679679433608125], prediction=1.0 (7, apache hadoop) --> prob=[0.6766082856652199,0.32339171433478003], prediction=0.0