MLlib--GBDT算法


轉載請標明出處http://www.cnblogs.com/haozhengfei/p/8b9cb1875288d9f6cfc2f5a9b2f10eac.html 


GBDT算法

江湖傳言:GBDT算法堪稱算法界的倚天劍屠龍刀
 
GBDT算法主要由三個部分組成:
 
                   – 
Regression Decistion Tree(即
DT)
   
回歸樹
                   – 
Gradient Boosting(即GB)
   
迭代提升
                   – 
Shrinkage(漸變)
   
漸變

1.決策樹

1.1決策樹的分類

決策樹 分類決策樹 用於分類標簽值,如晴天/陰天/霧/雨、用戶性別、網頁是否是垃圾頁面。
回歸決策樹 預測實數值,如明天的溫度、用戶的年齡、網頁的相關程度
強調:回歸決策樹的結果(數值)加減是有意義的,但是分類決策樹是沒有意義的,因為它是類別

1.2什么是回歸決策樹?

   回歸樹決策樹的總體流程可以類比分類樹,比如C4.5分類樹在每次分枝時,是窮舉每一個feature的每一
個閾值,找到使得按照feature<=閾值,和feature>閾值分成的兩個分枝 的熵最大的feature和閾值,按照該標准分枝得到兩個新節點,用同樣方法 繼續分枝直到所有人都被分入性別唯一的葉子節點,或達到預設的終止條
件,若最終葉子節點中的性別不唯一,則以多數人的性別作為該葉子節點 的性別。在回歸樹當中, 每個節點(不一定是葉子節點) 都會得 一個預測值,以年齡為例,該預測值等於屬於這個節點的所有人年齡的平
均值。分枝時窮舉每一個feature的每個閾值找最好的分割點,但 衡量最好 的標准不再是最大熵,而是最小化均方差--即(每個人的年齡-預測年齡) ^2 的總和 / N,或者說是每個人的預測誤差平方和 除以 N。這很好理解
被預測出錯的人數越多,錯的越離譜,均方差就越大,通過最小化均方 差能夠找到最靠譜的分枝依據。分枝直到每個葉子節點上人的年齡都唯一 (這太難了)或者達到預設的終止條件(如葉子個數上限),若最終葉子
節點上人的年齡不唯一,則以該節點上所有人的平均年齡做為該葉子節點的預測年齡。

1.3回歸決策樹划分的原則_CART算法

    CART算法思想:計算子節點的均方差,均方差越小,回歸決策樹越好。

2.GBDT算法_Boosting迭代

   即通過迭代多棵樹來共同決策
    GBDT的核心就在於,每一棵樹學的是之前所有樹結論和的殘差(比如A的真實年齡是18歲,但第一棵樹的預測年齡是12歲,差了6歲,即殘 差為6歲。那么在第二棵樹里我們把A的年齡設為6歲去學習,如果第二棵
樹真的能把A分到6歲的葉子節點,那累加兩棵樹的結論就是A的真實年齡 ;如果第二棵樹的結論是5歲,則A仍然存在1歲的殘差,第三棵樹里A的年 齡就變成1歲,繼續 學),這個殘 差就是一個加預測值后能得真實值的累加量。這一過程離不開Boosting迭代

2.2圖解Boosting迭代

 
Adaboost算法是另一種boost方法,它按分類對錯,分配不同的 weight,計算時使用這些weight,從而讓“錯分的樣本權重越來 越大,使它們更被重視”。
 

2.3GBDT算法_構建決策樹的步驟

• 0. 表示給定一個初始值
• 1. 表示建立M棵決策樹(迭代M次)
• 2. 表示對函數估計值F(x)進行Logistic變換
• 3. 表示對於K個分類進行下面的操作(其實這個for循環也可以理 解為向量的操作,每一個樣本點xi都對應了K種可能的分類yi,所 以yi, F(xi), p(xi)都是一個K維的向量,這樣或許容易理解一點)
• 4. 表示求得殘差減少的梯度方向
• 5. 表示根據每一個樣本點x,與其殘差減少的梯度方向,得到一棵 由J個葉子節點組成的決策樹
• 6. 為當決策樹建立完成后,通過這個公式,可以得到每一個葉子 節點的增益(這個增益在預測的時候用的)
每個增益的組成其實也是一個K維的向量,表示如果在決策樹預 測的過程中,如果某一個樣本點掉入了這個葉子節點,則其對應 的K個分類的值是多少。比如說,GBDT得到了三棵決策樹,一個 樣本點在預測的時候,也會掉入3個葉子節點上,其增益分別為( 假設為3分類的問題):  (0.5, 0.8, 0.1), (0.2, 0.6, 0.3), (0.4, 0.3, 0.3),那么這樣最終得到 的分類為第二個,因為選擇分類2的決策樹是最多的。
• 7. 將當前得到的決策樹與之前的那些決策樹合並起來 ,作為新的一個模型

2.4GBDT和其他的比較

2.4.1GBDT和隨機森林的比較

問題: GBDT和隨機森林都是基於決策樹的高級算法,都可以用來做分類和回歸,那么什么時候用GBDT? 什么時候用隨機森林?
     1.二者構建樹的差異:
                隨機森林采取有放回的抽樣構建的每棵樹基本是一樣的,多棵樹形成森林,采用投票機制決定最終的結果。
                GBDT通常只有第一個樹是完整的, 當預測值和真實值有一定差距時(殘差), 下一棵樹的構建會拿到上一棵樹最終的殘差作為當前樹的輸入。 GBDT每次關注的不是預測錯誤的樣本,沒有對錯一說,只有離標准相差的遠近。
     2.因為二者構建樹的差異,隨機森林采用有放回的抽樣進行構建決策樹,所以隨機森林相對於GBDT來說對於異常數據不是很敏感,但是GBDT不斷的關注殘差,導致最后的結果會非常的准確,不會出現欠擬合的情況,但是異常數據會干擾最后的決策。
    綜上所述:如果數據中異常值較多,那么采用隨機森林,否則采用GBDT。

2.4.2GBDT和SVM

GBDTSVM是最接近於神經網絡的算法,神經網絡每增加一層計算量呈幾何級增加,神經網絡在計算的時候倒着推,每得到一個結果,增加一些成分的權重,神經網絡內部就是通過不同的層次來訓練,然后增加比較重要的特征,降低那些沒有用對結果影響很小的維度的權重,這些過程在運行的時候都是內部自動做。如果 GBDT內部核函數是線性回歸(邏輯回歸),並且這些回歸的離散化,歸一化做得非常好,那么就可以趕得上神經網絡。 GBDT底層是線性組合來給我們做分類或者擬合,如果層次太深,或者迭代次數太多,就可能出現過擬合,比如原來用一條線分開的兩種數據,我們使用多條線來分類。

2.4.3如何用回歸決策樹來進行分類?

   把回歸決策樹最終的葉子節點上面的數據進行一個邏輯變換(logistic),然后對這樣的數據進行邏輯回歸,就可以使用回歸決策樹來進行分類了。

2.4.4數據處理--歸一化

歸一化:
   不同的特征數量級不一樣,(第一個特征0-1,第二個特征1000-10000,這時候就需要歸一化,都歸一到0-1之間)
 
歸一化兩種方式:
   線性歸一化 (x-min)/(max-min)
   零均值歸一化 (與期望和方差有關);

2.5回歸決策樹code

生成測試數據 LogisticRegressionDataGenerator
回歸決策樹代碼測試 GBDT_new
    bootstingStrategy.setLearningRate(0.8)`//設置梯度,迭代的快慢
 1 import org.apache.log4j.{Level, Logger}
 2 import org.apache.spark.mllib.feature.{StandardScaler, StandardScalerModel}
 3 import org.apache.spark.mllib.regression.LabeledPoint
 4 import org.apache.spark.mllib.tree.{GradientBoostedTrees, DecisionTree}
 5 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
 6 import org.apache.spark.mllib.tree.impurity.Entropy
 7 import org.apache.spark.mllib.util.MLUtils
 8 import org.apache.spark.rdd.RDD
 9 import org.apache.spark.{SparkConf, SparkContext}
10 
11 /**
12   * Created by hzf
13   */
14 object GBDT_new {
15 //    E:\IDEA_Projects\mlib\data\GBDT\train E:\IDEA_Projects\mlib\data\GBDT\train\model 10 local
16     def main(args: Array[String]) {
17         Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
18         if (args.length < 4) {
19             System.err.println("Usage: DecisionTrees <inputPath> <modelPath> <maxDepth> <master> [<AppName>]")
20             System.err.println("eg: hdfs://192.168.57.104:8020/user/000000_0 10 0.1 spark://192.168.57.104:7077 DecisionTrees")
21             System.exit(1)
22         }
23         val appName = if (args.length > 4) args(4) else "DecisionTrees"
24         val conf = new SparkConf().setAppName(appName).setMaster(args(3))
25         val sc = new SparkContext(conf)
26 
27         val traindata: RDD[LabeledPoint] = MLUtils.loadLabeledPoints(sc, args(0))
28         val features = traindata.map(_.features)
29         val scaler: StandardScalerModel = new StandardScaler(withMean = true, withStd = true).fit(features)
30         val train: RDD[LabeledPoint] = traindata.map(sample => {
31             val label = sample.label
32             val feature = scaler.transform(sample.features)
33             new LabeledPoint(label, feature)
34         })
35         val splitRdd: Array[RDD[LabeledPoint]] = traindata.randomSplit(Array(1.0, 9.0))
36         val testData: RDD[LabeledPoint] = splitRdd(0)
37         val realTrainData: RDD[LabeledPoint] = splitRdd(1)
38 
39         val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams("Classification")
40         boostingStrategy.setNumIterations(3)
41         boostingStrategy.treeStrategy.setNumClasses(2)
42         boostingStrategy.treeStrategy.setMaxDepth(args(2).toInt)
43         boostingStrategy.setLearningRate(0.8)
44         //  boostingStrategy.treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())
45         val model = GradientBoostedTrees.train(realTrainData, boostingStrategy)
46 
47         val labelAndPreds = testData.map(point => {
48             val prediction = model.predict(point.features)
49             (point.label, prediction)
50         })
51         val acc = labelAndPreds.filter(r => r._1 == r._2).count.toDouble / testData.count()
52 
53         println("Test Error = " + acc)
54 
55         model.save(sc, args(1))
56     }
57 }
View Code
設置運行參數
  1. E:\IDEA_Projects\mlib\data\GBDT\train E:\IDEA_Projects\mlib\data\GBDT\train\model 10 local


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM