1. Concept
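Logistic regression is a popular method for predicting a categorical response. It is a special case of generalized linear models that predicts the probability of each outcome. In spark.ml, logistic regression can predict a binary outcome with binomial logistic regression, or a multiclass outcome with multinomial logistic regression. Use the family parameter to choose between the two algorithms, or leave it unset (the default is auto) and Spark will infer the correct variant. Multinomial logistic regression can also be used for binary classification by setting the family parameter to "multinomial"; in that case it produces two sets of coefficients and two intercepts.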
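To illustrate why a two-class multinomial model carries two sets of coefficients and two intercepts yet describes the same kind of model as the binomial one, the plain-Scala sketch below (no Spark dependency) compares a two-class softmax with a sigmoid applied to the difference of the two margins. The coefficients, intercepts, feature vector, and the names `FamilySketch`, `sigmoid`, `softmax`, and `margin` are made up for illustration; they do not come from a fitted Spark model, and this is not Spark's implementation.

```scala
object FamilySketch {
  // Logistic (sigmoid) function, the link used by a binomial model.
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Softmax over per-class margins, the link used by a multinomial model.
  def softmax(margins: Seq[Double]): Seq[Double] = {
    val shift = margins.max // subtract the max for numerical stability
    val exps = margins.map(m => math.exp(m - shift))
    val total = exps.sum
    exps.map(_ / total)
  }

  // Linear margin w . x + b for one class.
  def margin(w: Seq[Double], b: Double, x: Seq[Double]): Double =
    w.zip(x).map { case (wi, xi) => wi * xi }.sum + b

  // Hypothetical feature vector and per-class parameters (assumed values).
  val x: Seq[Double] = Seq(1.5, -0.5)
  val m0: Double = margin(Seq(0.2, -0.4), 0.1, x)  // class 0: coefficients + intercept
  val m1: Double = margin(Seq(-0.3, 0.6), -0.2, x) // class 1: coefficients + intercept

  // Probability of class 1 under both formulations.
  val pMultinomial: Double = softmax(Seq(m0, m1))(1)
  val pBinomial: Double = sigmoid(m1 - m0)

  def main(args: Array[String]): Unit = {
    println(s"multinomial form: $pMultinomial")
    println(s"binomial form:    $pBinomial")
  }
}
```

The two values agree up to floating-point rounding, because e^m1 / (e^m0 + e^m1) = 1 / (1 + e^-(m1 - m0)). Only the difference between the two coefficient sets matters for the prediction, which is why the binomial family needs a single set.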
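In a classification problem, we try to predict whether an outcome belongs to a particular class (for example, correct or incorrect). Examples include deciding whether an email is spam and deciding whether a financial transaction is fraudulent.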
2. Code (reference: https://github.com/asker124143222/spark-demo)
package com.home.spark.ml

import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{Dataset, Row, SparkSession}

/**
  * @Description: Logistic regression, binomial classification
  **/
object Ex_BinomialLogisticRegression {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setMaster("local[*]").setAppName("spark ml label")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Converting an RDD to a DataFrame or Dataset needs the implicit conversions of the SparkSession instance.
    // Note that `spark` here is the name of the SparkSession object, not a package name.
    import spark.implicits._

    val data = spark.sparkContext.textFile("input/iris.data.txt")
      .map(_.split(","))
      .map(a => Iris(
        Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble),
        a(4))
      ).toDF()
    data.show()

    data.createOrReplaceTempView("iris")
    val TotalCount = spark.sql("select count(*) from iris")
    println("Record count: " + TotalCount.collect().take(1).mkString)

    // Binomial prediction: the sample data has three classes, so Iris-setosa is excluded.
    val df = spark.sql("select * from iris where label!='Iris-setosa'")
    df.map(r => s"${r(1)} : ${r(0)}").take(10).foreach(println)
    println("Record count after filtering: " + df.count())

    /*
      VectorIndexer can improve the results of ML methods such as decision trees and random forests.
      It indexes the categorical (discrete-valued) features inside the feature vectors of a dataset.
      It decides automatically which features are categorical and indexes them:
      given maxCategories, a feature with at most maxCategories distinct values is re-indexed
      to 0..K (K <= maxCategories-1); a feature with more than maxCategories distinct values
      is treated as continuous and left unchanged.
      For example, with maxCategories=5, feature columns with at most 5 distinct values are re-indexed.
      For index stability, a feature value of 0 is always mapped to index 0, which preserves vector sparsity.
      maxCategories defaults to 20.
     */

    // Index the label column and the feature column.
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
    val featureIndexer = new VectorIndexer()
      // .setMaxCategories(5) // with 5, every feature column has more than 5 distinct values, so nothing would be indexed and the stage would be pointless
      .setInputCol("features").setOutputCol("indexedFeatures")
      .fit(df)

    // Split the data set into training data (70%) and test data (30%).
    val Array(trainingData, testData): Array[Dataset[Row]] = df.randomSplit(Array(0.7, 0.3))

    /**
      * Logistic regression parameters
      * setMaxIter: maximum number of iterations (default 100); training may stop earlier
      * setTol: convergence tolerance (default 1e-6); iteration stops once the per-iteration error, which shrinks as training proceeds, falls below it
      * setRegParam: regularization coefficient (default 0); regularization guards against overfitting, so consider raising it when the data set is small and has many features
      * setElasticNetParam: mixing ratio between the two penalties (default 0); 0 is pure L2 (Ridge), 1 is pure L1 (Lasso). L1 makes the coefficients sparse, L2 guards against overfitting
      * setLabelCol: label column
      * setFeaturesCol: feature column
      * setPredictionCol: prediction column
      * setThreshold: threshold for binary classification
      */
    // Configure logistic regression; "binomial" because two classes remain after filtering.
    val lr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setFamily("binomial")
      .setMaxIter(100).setRegParam(0.3).setElasticNetParam(0.8)

    // Transformer that converts the predicted class index back to its string label.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predectionLabel")
      .setLabels(labelIndexer.labels)

    // Build the pipeline.
    val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))

    // Fit the model.
    val model = lrPipeline.fit(trainingData)

    // Predict.
    val result = model.transform(testData)

    // Print the result.
    result.show(200, false)

    // Model evaluation. The evaluator's default metric is "f1", so the value printed below is
    // 1 - weighted F1; call .setMetricName("accuracy") on the evaluator to get 1 - accuracy instead.
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
    val lrScore: Double = evaluator.evaluate(result)
    println("Test Error = " + (1.0 - lrScore))

    spark.stop()
  }
}

case class Iris(features: Vector, label: String)
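The VectorIndexer rule described in the code comment can be mimicked in a few lines of plain Scala. This is only a sketch of the described behaviour, not Spark's implementation: `VectorIndexerSketch` and `categoryMap` are hypothetical names, the sample values are made up, and the ascending order of the non-zero categories is an assumption that matches the indexedFeatures column in the result of section 3.

```scala
object VectorIndexerSketch {
  // For one feature column, return the value -> index map if the column is
  // treated as categorical, or None if it is treated as continuous.
  def categoryMap(values: Seq[Double], maxCategories: Int): Option[Map[Double, Int]] = {
    val distinct = values.distinct
    if (distinct.size > maxCategories) {
      None // too many distinct values: left unchanged as a continuous feature
    } else {
      // Assumed ordering: 0.0 first (if present), then the other values ascending.
      val nonZero = distinct.filter(_ != 0.0).sorted
      val ordered = if (distinct.contains(0.0)) 0.0 +: nonZero else nonZero
      Some(ordered.zipWithIndex.toMap)
    }
  }

  def main(args: Array[String]): Unit = {
    val column = Seq(2.4, 2.0, 2.5, 2.0, 2.4) // three distinct values
    // Categorical: 2.0 maps to 0, 2.4 to 1, 2.5 to 2.
    println(categoryMap(column, maxCategories = 20))
    // Continuous: three distinct values exceed maxCategories = 2, so None.
    println(categoryMap(column, maxCategories = 2))
  }
}
```

With the default maxCategories of 20, a measurement column that happens to have few distinct values is replaced by category indices, which is how a feature vector such as [4.9,2.4,3.3,1.0] can end up as [4.9,3.0,3.3,0.0] in the indexedFeatures column.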
3. Result
+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
|features         |label          |indexedLabel|indexedFeatures    |rawPrediction                               |probability                             |prediction|predectionLabel|
+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
|[4.9,2.4,3.3,1.0]|Iris-versicolor|0.0         |[4.9,3.0,3.3,0.0]  |[1.0071037675553336,-1.0071037675553336]    |[0.7324529695042751,0.2675470304957249] |0.0       |Iris-versicolor|
|[5.0,2.0,3.5,1.0]|Iris-versicolor|0.0         |[5.0,0.0,3.5,0.0]  |[0.938177922699384,-0.938177922699384]      |[0.7187314594034615,0.2812685405965385] |0.0       |Iris-versicolor|
|[5.6,2.5,3.9,1.1]|Iris-versicolor|0.0         |[5.6,4.0,3.9,1.0]  |[0.7107814076350716,-0.7107814076350716]    |[0.6705737993354417,0.3294262006645583] |0.0       |Iris-versicolor|
|[5.6,2.9,3.6,1.3]|Iris-versicolor|0.0         |[5.6,8.0,3.6,3.0]  |[0.6350805242141693,-0.6350805242141693]    |[0.6536405613705153,0.3463594386294846] |0.0       |Iris-versicolor|
|[5.8,2.7,4.1,1.0]|Iris-versicolor|0.0         |[5.8,6.0,4.1,0.0]  |[0.7314003881315354,-0.7314003881315354]    |[0.6751125028597408,0.32488749714025916]|0.0       |Iris-versicolor|
|[6.1,2.8,4.7,1.2]|Iris-versicolor|0.0         |[6.1,7.0,4.7,2.0]  |[0.34553320285886,-0.34553320285886]        |[0.5855339747983552,0.41446602520164466]|0.0       |Iris-versicolor|
|[6.2,2.2,4.5,1.5]|Iris-versicolor|0.0         |[6.2,1.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
|[6.4,2.9,4.3,1.3]|Iris-versicolor|0.0         |[6.4,8.0,4.3,3.0]  |[0.39384006721834597,-0.39384006721834597]  |[0.597206774507057,0.40279322549294305] |0.0       |Iris-versicolor|
|[6.6,3.0,4.4,1.4]|Iris-versicolor|0.0         |[6.6,9.0,4.4,4.0]  |[0.2698323194379575,-0.2698323194379575]    |[0.5670517391689078,0.43294826083109217]|0.0       |Iris-versicolor|
|[6.7,3.0,5.0,1.7]|Iris-versicolor|0.0         |[6.7,9.0,5.0,7.0]  |[-0.20557969118713126,0.20557969118713126]  |[0.44878532413929256,0.5512146758607075]|1.0       |Iris-virginica |
|[6.7,3.1,4.4,1.4]|Iris-versicolor|0.0         |[6.7,10.0,4.4,4.0] |[0.2698323194379575,-0.2698323194379575]    |[0.5670517391689078,0.43294826083109217]|0.0       |Iris-versicolor|
|[7.0,3.2,4.7,1.4]|Iris-versicolor|0.0         |[7.0,11.0,4.7,4.0] |[0.16644355215403328,-0.16644355215403328]  |[0.5415150896404186,0.4584849103595813] |0.0       |Iris-versicolor|
|[4.9,2.5,4.5,1.7]|Iris-virginica |1.0         |[4.9,4.0,4.5,7.0]  |[-0.033265079047257284,0.033265079047257284]|[0.49168449702809164,0.5083155029719083]|1.0       |Iris-virginica |
|[5.4,3.0,4.5,1.5]|Iris-versicolor|0.0         |[5.4,9.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
|[5.6,2.8,4.9,2.0]|Iris-virginica |1.0         |[5.6,7.0,4.9,10.0] |[-0.43975124481639627,0.43975124481639627]  |[0.39180024423019144,0.6081997557698086]|1.0       |Iris-virginica |
|[5.6,3.0,4.1,1.3]|Iris-versicolor|0.0         |[5.6,9.0,4.1,3.0]  |[0.4627659120742955,-0.4627659120742955]    |[0.6136701219061476,0.38632987809385244]|0.0       |Iris-versicolor|
|[5.8,2.7,3.9,1.2]|Iris-versicolor|0.0         |[5.8,6.0,3.9,2.0]  |[0.6212365822826582,-0.6212365822826582]    |[0.6504997376392441,0.34950026236075604]|0.0       |Iris-versicolor|
|[5.8,2.7,5.1,1.9]|Iris-virginica |1.0         |[5.8,6.0,5.1,9.0]  |[-0.419132264319932,0.419132264319932]      |[0.3967244102962335,0.6032755897037665] |1.0       |Iris-virginica |
|[5.9,3.0,5.1,1.8]|Iris-virginica |1.0         |[5.9,9.0,5.1,8.0]  |[-0.32958743896751885,0.32958743896751885]  |[0.4183410089972438,0.5816589910027563] |1.0       |Iris-virginica |
|[6.0,2.9,4.5,1.5]|Iris-versicolor|0.0         |[6.0,8.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
|[6.1,3.0,4.6,1.4]|Iris-versicolor|0.0         |[6.1,9.0,4.6,4.0]  |[0.20090647458200817,-0.20090647458200817]  |[0.5500583546439539,0.4499416453560461] |0.0       |Iris-versicolor|
|[6.2,3.4,5.4,2.3]|Iris-virginica |1.0         |[6.2,13.0,5.4,13.0]|[-0.8807003330135101,0.8807003330135101]    |[0.29303267372325625,0.7069673262767437]|1.0       |Iris-virginica |
|[6.7,3.1,4.7,1.5]|Iris-versicolor|0.0         |[6.7,10.0,4.7,5.0] |[0.07689872680162013,-0.07689872680162013]  |[0.5192152136737482,0.48078478632625177]|0.0       |Iris-versicolor|
|[6.7,3.3,5.7,2.5]|Iris-virginica |1.0         |[6.7,12.0,5.7,15.0]|[-1.163178751002261,1.163178751002261]      |[0.23809016943453823,0.7619098305654617]|1.0       |Iris-virginica |
|[6.8,3.0,5.5,2.1]|Iris-virginica |1.0         |[6.8,9.0,5.5,11.0] |[-0.7360736047366578,0.7360736047366578]    |[0.32386333429517283,0.6761366657048272]|1.0       |Iris-virginica |
|[6.9,3.1,5.4,2.1]|Iris-virginica |1.0         |[6.9,10.0,5.4,11.0]|[-0.7016106823086834,0.7016106823086834]    |[0.33145521561995817,0.6685447843800418]|1.0       |Iris-virginica |
|[7.2,3.6,6.1,2.5]|Iris-virginica |1.0         |[7.2,14.0,6.1,15.0]|[-1.3010304407141597,1.3010304407141597]    |[0.21399164655179387,0.7860083534482062]|1.0       |Iris-virginica |
|[7.7,2.8,6.7,2.0]|Iris-virginica |1.0         |[7.7,7.0,6.7,10.0] |[-1.0600838485199424,1.0600838485199424]    |[0.2572934314622856,0.7427065685377143] |1.0       |Iris-virginica |
|[7.7,3.0,6.1,2.3]|Iris-virginica |1.0         |[7.7,9.0,6.1,13.0] |[-1.1219407900093334,1.1219407900093334]    |[0.24565146441425778,0.7543485355857422]|1.0       |Iris-virginica |
|[7.9,3.8,6.4,2.0]|Iris-virginica |1.0         |[7.9,15.0,6.4,10.0]|[-0.9566950812360182,0.9566950812360182]    |[0.2775403823663211,0.7224596176336789] |1.0       |Iris-virginica |
+-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+

Test Error = 0.03314285714285714
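
The numbers above can be cross-checked with a short plain-Scala sketch, offered as an aid to reading the output rather than as part of the original program. It assumes that, for the binomial model, probability(0) is the sigmoid of rawPrediction(0), and that the evaluator reports a weighted F1, which is the default metric of MulticlassClassificationEvaluator when setMetricName is not called. The counts are read off the 30 rows shown (18 Iris-versicolor with one predicted as Iris-virginica, 12 Iris-virginica all predicted correctly); `ResultCheck` and its members are illustrative names.

```scala
object ResultCheck {
  // Logistic function that maps a raw margin to a probability.
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // F1 of one class from its true positives, false positives and false negatives.
  def f1(tp: Int, fp: Int, fn: Int): Double = 2.0 * tp / (2.0 * tp + fp + fn)

  // Confusion counts read off the result table.
  val f1Versicolor: Double = f1(tp = 17, fp = 0, fn = 1)
  val f1Virginica: Double = f1(tp = 12, fp = 1, fn = 0)

  // F1 weighted by the number of true rows of each class (18 and 12 out of 30).
  val weightedF1: Double = (18 * f1Versicolor + 12 * f1Virginica) / 30
  val accuracy: Double = 29.0 / 30

  val errorFromF1: Double = 1.0 - weightedF1
  val errorFromAccuracy: Double = 1.0 - accuracy

  def main(args: Array[String]): Unit = {
    // First row: rawPrediction(0) = 1.0071037675553336, probability(0) = 0.7324529695042751.
    println(s"sigmoid of first raw margin: ${sigmoid(1.0071037675553336)}")
    // Agrees with the printed Test Error (0.0331...) up to rounding.
    println(s"1 - weighted F1: $errorFromF1")
    // Slightly different: one error in 30 rows is about 0.0333.
    println(s"1 - accuracy:    $errorFromAccuracy")
  }
}
```

Under these assumptions the printed Test Error corresponds to 1 - weighted F1 rather than to the misclassification rate; on this test split the two differ only in the fourth decimal place.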