Spark MLlib Classification Algorithms: Naive Bayes Classification
(I) Understanding Naive Bayes classification
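Naive Bayes is a classification method based on Bayes' theorem and the assumption that features are conditionally independent given the class. Put simply, a naive Bayes classifier assumes that each feature of a sample is independent of every other feature once the class is known. For example, a fruit that is red, round, and about 4 inches in diameter may be judged to be an apple. Even if these features depend on one another, or some of them are determined by others, the naive Bayes classifier treats each one as contributing independently to the probability that the fruit is an apple. Despite this naive design and oversimplified assumption, naive Bayes classifiers still perform quite well in many complex real-world situations. One advantage of naive Bayes is that it needs only a small amount of training data to estimate the necessary parameters (for discrete variables, the prior probabilities and the class-conditional probabilities; for continuous variables, the mean and variance of the variable within each class).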
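Written as formulas, the independence assumption turns classification into a product of per-attribute terms. The block below is a standard statement of the decision rule; the notation is assumed here rather than taken from the original: x_1, ..., x_n are the attribute values of a record, and for a continuous attribute μ_y and σ_y² are its mean and variance within class y.

```latex
P(y \mid x_1,\dots,x_n) \propto P(y)\prod_{i=1}^{n} P(x_i \mid y),
\qquad
\hat{y} = \arg\max_{y}\; P(y)\prod_{i=1}^{n} P(x_i \mid y)

% continuous attribute: Gaussian class-conditional density
P(x_i \mid y) = \frac{1}{\sqrt{2\pi\sigma_{y}^{2}}}
  \exp\!\left(-\frac{(x_i-\mu_{y})^{2}}{2\sigma_{y}^{2}}\right)
```

The denominator P(x_1, ..., x_n) of Bayes' theorem is the same for every class, so it can be dropped when only the winning class matters.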
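Worked example:
From this dataset, the prior probabilities, the class-conditional probabilities of each discrete attribute, and the parameters (sample mean and sample variance) of the class-conditional distribution of the continuous attribute are: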
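Prior probabilities: P(Yes) = 0.3; P(No) = 0.7
P(Home Owner = Yes | No) = 3/7
P(Home Owner = No | No) = 4/7
P(Home Owner = Yes | Yes) = 0
P(Home Owner = No | Yes) = 1
P(Marital Status = Single | No) = 2/7
P(Marital Status = Divorced | No) = 1/7
P(Marital Status = Married | No) = 4/7
P(Marital Status = Single | Yes) = 2/3
P(Marital Status = Divorced | Yes) = 1/3
P(Marital Status = Married | Yes) = 0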
Annual income (in thousands):
If class = No: sample mean = 110; sample variance = 2975
If class = Yes: sample mean = 90; sample variance = 25
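As a check on the income parameters, the minimal plain-Scala sketch below (no Spark) evaluates the Gaussian class-conditional density of an annual income of 120K under each class. `gaussianPdf` is a hypothetical helper name, and the sketch assumes the sample variance above is used directly as σ².

```scala
// Gaussian class-conditional density for a continuous attribute,
// using the per-class sample mean and sample variance from the text.
def gaussianPdf(x: Double, mean: Double, variance: Double): Double =
  math.exp(-math.pow(x - mean, 2) / (2 * variance)) / math.sqrt(2 * math.Pi * variance)

// Annual income = 120 (in thousands), as in the record to classify
val pIncomeGivenNo  = gaussianPdf(120, mean = 110, variance = 2975)
val pIncomeGivenYes = gaussianPdf(120, mean = 90, variance = 25)
println(f"P(income=120K|No)  = $pIncomeGivenNo%.4f")
println(f"P(income=120K|Yes) = $pIncomeGivenYes%.2e")
```

The two values should agree with the 0.0072 and 1.2*10^-9 used in the worked calculation that follows.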
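Record to classify: X = {Home Owner = No, Marital Status = Married, Annual Income = 120K}
The income terms are Gaussian densities evaluated at 120 with each class's mean and variance:
P(No)*P(Home Owner = No | No)*P(Marital Status = Married | No)*P(Annual Income = 120K | No) = 0.7*4/7*4/7*0.0072 ≈ 0.0016
P(Yes)*P(Home Owner = No | Yes)*P(Marital Status = Married | Yes)*P(Annual Income = 120K | Yes) = 0.3*1*0*1.2*10^-9 = 0
Since 0.0016 > 0, the record is classified as No.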
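The same calculation can be reproduced in a few lines of plain Scala. This sketch hard-codes the probabilities listed above rather than estimating them from data, and `scores` and `gaussianPdf` are illustrative names. It keeps the unnormalized scores, since the common denominator P(X) does not change which class wins.

```scala
// Score each class as prior * product of class-conditional probabilities,
// then pick the class with the larger score (naive Bayes decision rule).
def gaussianPdf(x: Double, mean: Double, variance: Double): Double =
  math.exp(-math.pow(x - mean, 2) / (2 * variance)) / math.sqrt(2 * math.Pi * variance)

// Record X = {Home Owner = No, Marital Status = Married, Annual Income = 120K}
val scores: Map[String, Double] = Map(
  "No"  -> 0.7 * (4.0 / 7) * (4.0 / 7) * gaussianPdf(120, 110, 2975),
  "Yes" -> 0.3 * 1.0 * 0.0 * gaussianPdf(120, 90, 25) // P(Married | Yes) = 0
)
val predicted = scores.maxBy(_._2)._1
scores.foreach { case (cls, s) => println(f"$cls: $s%.6f") }
println(s"predicted class: $predicted")
```

A single zero factor, here P(Marital Status = Married | Yes), forces the Yes score to 0 regardless of the other attributes, which is the weakness discussed next.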
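The example shows that if the class-conditional probability of even one attribute is 0, the posterior probability of the whole class becomes 0. Estimating class-conditional probabilities purely from record proportions is therefore too fragile, especially when there are few training examples and many attributes. One solution is to estimate the conditional probabilities with the m-estimate:
P(x_i | y_j) = (n_c + m*p) / (n + m)
where n is the number of training records of class y_j, n_c is the number of those records whose attribute takes the value x_i, m is a parameter called the equivalent sample size, and p is a user-specified prior estimate of the probability.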
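A minimal plain-Scala sketch of the m-estimate. The values m = 3 and p = 1/3 (one over the number of marital-status values) are assumptions chosen for illustration; the text does not fix them. With these choices the zero estimate P(Married | Yes) becomes 1/6.

```scala
// m-estimate of a class-conditional probability:
//   P(x_i | y_j) = (n_c + m * p) / (n + m)
// n   : number of training records in class y_j
// n_c : number of those records whose attribute takes value x_i
// m   : equivalent sample size (how strongly to trust the prior p)
// p   : prior estimate of the probability
def mEstimate(nc: Int, n: Int, m: Double, p: Double): Double =
  (nc + m * p) / (n + m)

// Assumed for illustration: Marital Status has 3 values, so p = 1/3, and m = 3.
val marriedGivenYes = mEstimate(nc = 0, n = 3, m = 3, p = 1.0 / 3) // raw ratio was 0/3
val marriedGivenNo  = mEstimate(nc = 4, n = 7, m = 3, p = 1.0 / 3) // raw ratio was 4/7
println(f"P(Married|Yes) = $marriedGivenYes%.4f, P(Married|No) = $marriedGivenNo%.4f")
```

As m grows the estimate moves toward p; with m = 0 it falls back to the raw ratio n_c / n.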
(II) Applying Naive Bayes with Spark MLlib
1. Download the dataset: train.tsv and test.tsv from http://www.kaggle.com/c/stumbleupon/data
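2. Preprocess the dataset
1) Remove the header line: sed 1d train.tsv > train_nohead.tsv
2) Clean the data to obtain the training set: strip the quote characters, replace missing values ("?") with 0.0, and clamp negative values to 0.0 (Spark's NaiveBayes requires non-negative feature values):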
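import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// sc: the SparkContext provided by spark-shell
val orig_file = sc.textFile("train_nohead.tsv")
val ndata_file = orig_file.map(_.split("\t")).map { r =>
  // strip the double quotes around every field
  val trimmed = r.map(_.replace("\"", ""))
  // the last column is the label (0 or 1)
  val label = trimmed(r.length - 1).toDouble
  // columns 4 .. n-2 are numeric features; "?" marks a missing value
  val feature = trimmed.slice(4, r.length - 1)
    .map(d => if (d == "?") 0.0 else d.toDouble)
    .map(d => if (d < 0) 0.0 else d) // NaiveBayes rejects negative values
  LabeledPoint(label, Vectors.dense(feature))
}.randomSplit(Array(0.7, 0.3), 11L) // split into training (70%) and test (30%) sets
val ndata_train = ndata_file(0).cache() // training set
val ndata_test = ndata_file(1)          // test set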
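To see what the cleaning does to one record without starting Spark, the sketch below applies the same steps to a single TSV line. `parseRow` is a hypothetical helper that returns a plain (label, features) pair instead of an MLlib `LabeledPoint`, and the sample row is made up; it only mimics the column layout assumed by `slice(4, r.length - 1)`.

```scala
// Plain-Scala version of the per-row cleaning done inside the Spark map.
def parseRow(line: String): (Double, Array[Double]) = {
  val r = line.split("\t")
  val trimmed = r.map(_.replace("\"", ""))       // drop the surrounding quotes
  val label = trimmed(r.length - 1).toDouble     // last column is the label
  val features = trimmed.slice(4, r.length - 1)  // numeric columns only
    .map(d => if (d == "?") 0.0 else d.toDouble) // missing value -> 0.0
    .map(d => if (d < 0) 0.0 else d)             // negative value -> 0.0
  (label, features)
}

// Made-up row: url, urlid, boilerplate, category, three numeric fields, label
val row = Seq("http://example.com", "1", "{}", "business", "0.5", "?", "-1", "1")
  .map(f => "\"" + f + "\"")
  .mkString("\t")
val (label, features) = parseRow(row)
println(s"label = $label, features = ${features.mkString("[", ", ", "]")}")
```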
3. Train the Naive Bayes model and evaluate it (accuracy, PR curve, ROC curve)
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val model_NB = NaiveBayes.train(ndata_train)
/* accuracy of the Naive Bayes predictions, measured here on the training split */
val correct_NB = ndata_train.map { point =>
  if (model_NB.predict(point.features) == point.label) 1 else 0
}.sum() / ndata_train.count() // 0.565959409594096
/* area under the precision-recall (PR) curve and the ROC curve */
val metricsNb = Seq(model_NB).map { model =>
  // predict() returns a hard 0/1 class label, not a probability score
  val scoreAndLabels = ndata_train.map { point =>
    (model.predict(point.features), point.label)
  }
  val metrics = new BinaryClassificationMetrics(scoreAndLabels)
  (model.getClass.getSimpleName, metrics.areaUnderPR(), metrics.areaUnderROC())
}
metricsNb.foreach { case (m, pr, roc) =>
  println(f"$m, Area under PR: ${pr * 100.0}%2.4f%%, Area under ROC: ${roc * 100.0}%2.4f%%")
}
/* NaiveBayesModel, Area under PR: 68.0851%, Area under ROC: 58.3559% */
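`NaiveBayesModel.predict` returns a class label (0.0 or 1.0) rather than a probability, so `BinaryClassificationMetrics` receives only two distinct scores and the ROC curve collapses to a single operating point. The plain-Scala sketch below is an illustrative re-implementation, not Spark's code; the function names and sample pairs are hypothetical. It computes accuracy and the trapezoidal ROC AUC for such hard predictions, which works out to (1 + TPR - FPR) / 2.

```scala
// Metrics for (prediction, label) pairs where the prediction is a hard 0.0/1.0 class.
def accuracy(pairs: Seq[(Double, Double)]): Double =
  pairs.count { case (p, l) => p == l }.toDouble / pairs.size

// ROC AUC by the trapezoidal rule. With hard predictions there is only one
// non-trivial threshold, so the curve is (0,0) -> (FPR, TPR) -> (1,1).
def hardPredictionRocAuc(pairs: Seq[(Double, Double)]): Double = {
  val tp = pairs.count { case (p, l) => p == 1.0 && l == 1.0 }
  val fp = pairs.count { case (p, l) => p == 1.0 && l == 0.0 }
  val pos = pairs.count(_._2 == 1.0)
  val neg = pairs.size - pos
  val tpr = tp.toDouble / pos
  val fpr = fp.toDouble / neg
  val points = Seq((0.0, 0.0), (fpr, tpr), (1.0, 1.0))
  points.zip(points.tail).map { case ((x0, y0), (x1, y1)) => (x1 - x0) * (y0 + y1) / 2 }.sum
}

val pairs = Seq((1.0, 1.0), (1.0, 0.0), (0.0, 0.0), (0.0, 1.0), (1.0, 1.0))
println(f"accuracy = ${accuracy(pairs)}%.4f, ROC AUC = ${hardPredictionRocAuc(pairs)}%.4f")
```

For these five pairs TPR = 2/3 and FPR = 1/2, so the AUC is 7/12 ≈ 0.583 while the accuracy is 0.6. Passing a class probability as the score instead (MLlib's `NaiveBayesModel` offers `predictProbabilities` in Spark 1.5 and later) would give the metrics a full curve to integrate.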
4. Model tuning
1) Change the feature selection: use the text category column instead, encoded with the 1-of-k method
/* new feature: the category text column r(3), 1-of-k encoded */
val categories = orig_file.map(_.split("\t")).map(r => r(3)).distinct.collect.zipWithIndex.toMap
val dataNB = orig_file.map(_.split("\t")).map { r =>
  val trimmed = r.map(_.replaceAll("\"", ""))
  val label = trimmed(r.length - 1).toDouble
  // a vector of zeros with a single 1.0 at this record's category index
  val categoryIdx = categories(r(3))
  val categoryFeatures = Array.ofDim[Double](categories.size)
  categoryFeatures(categoryIdx) = 1.0
  LabeledPoint(label, Vectors.dense(categoryFeatures))
}.randomSplit(Array(0.7, 0.3), 11L)
val dataNB_train = dataNB(0)
val dataNB_test = dataNB(1)
/* train Naive Bayes on the category features */
val model_NB = NaiveBayes.train(dataNB_train)
/* accuracy on the test split */
val correct_NB = dataNB_test.map { point =>
  if (model_NB.predict(point.features) == point.label) 1 else 0
}.sum() / dataNB_test.count() // 0.6111623616236163
/* area under the PR and ROC curves */
val metricsNb = Seq(model_NB).map { model =>
  val scoreAndLabels = dataNB_test.map { point =>
    (model.predict(point.features), point.label)
  }
  val metrics = new BinaryClassificationMetrics(scoreAndLabels)
  (model.getClass.getSimpleName, metrics.areaUnderPR(), metrics.areaUnderROC())
}
metricsNb.foreach { case (m, pr, roc) =>
  println(f"$m, Area under PR: ${pr * 100.0}%2.4f%%, Area under ROC: ${roc * 100.0}%2.4f%%")
}
/* NaiveBayesModel, Area under PR: 74.8977%, Area under ROC: 60.1735% */
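The 1-of-k encoding step can be sketched without Spark as follows; `buildIndex` and `oneOfK` are hypothetical helper names and the category strings are made up. One difference worth noting: a local `Seq.distinct` keeps first-occurrence order, whereas the order produced by the RDD `distinct.collect` call in the Spark code is not specified, so the index assigned to each category there may differ.

```scala
// Map each distinct category string to a position 0..k-1.
def buildIndex(values: Seq[String]): Map[String, Int] =
  values.distinct.zipWithIndex.toMap

// 1-of-k (one-hot) vector: all zeros except a 1.0 at the category's position.
def oneOfK(value: String, index: Map[String, Int]): Array[Double] = {
  val v = Array.ofDim[Double](index.size)
  v(index(value)) = 1.0
  v
}

// Hypothetical category values, for illustration only
val categoryIndex = buildIndex(Seq("business", "recreation", "business", "sports"))
println(oneOfK("sports", categoryIndex).mkString("[", ", ", "]"))
```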
2) Tune the smoothing parameter lambda: the effect is not noticeable (every value tried gives the same AUC)
import org.apache.spark.rdd.RDD

/* vary the additive-smoothing parameter lambda (default 1.0) */
def trainNBWithParams(input: RDD[LabeledPoint], lambda: Double) = {
  val nb = new NaiveBayes
  nb.setLambda(lambda)
  nb.run(input)
}
val nbResults = Seq(0.001, 0.01, 0.1, 1.0, 10.0).map { param =>
  val model = trainNBWithParams(dataNB_train, param)
  val scoreAndLabels = dataNB_test.map { point =>
    (model.predict(point.features), point.label)
  }
  val metrics = new BinaryClassificationMetrics(scoreAndLabels)
  (s"$param lambda", metrics.areaUnderROC())
}
nbResults.foreach { case (param, auc) =>
  println(f"$param, AUC = ${auc * 100}%2.2f%%")
}
/*results
0.001 lambda, AUC = 60.17%
0.01 lambda, AUC = 60.17%
0.1 lambda, AUC = 60.17%
1.0 lambda, AUC = 60.17%
10.0 lambda, AUC = 60.17%
*/
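Spark MLlib's multinomial NaiveBayes applies lambda as additive smoothing to each class's per-feature totals (it stores the logarithms of the resulting probabilities). This is the m-estimate with p = 1/k and m = k * lambda for k features. The plain-Scala sketch below uses hypothetical counts and an illustrative function name. With large per-class counts, lambda values between 0.001 and 10 barely move these ratios, which is one plausible reason the hard predictions, and hence the AUC, did not change.

```scala
// Additive (Laplace / Lidstone) smoothing of one class's per-feature totals:
//   theta_j = (count_j + lambda) / (sum_j count_j + numFeatures * lambda)
def smoothedProbs(counts: Array[Double], lambda: Double): Array[Double] = {
  val denom = counts.sum + counts.length * lambda
  counts.map(c => (c + lambda) / denom)
}

// Hypothetical per-feature totals for one class (e.g. three one-hot categories)
val counts = Array(3.0, 0.0, 1.0)
for (lambda <- Seq(0.001, 0.01, 0.1, 1.0, 10.0)) {
  val theta = smoothedProbs(counts, lambda)
  println(s"lambda = $lambda, theta = ${theta.map(t => f"$t%.4f").mkString(", ")}")
}
```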