決策樹算法原理及實現


(一)認識決策樹

1、決策樹分類原理

   決策樹是通過一系列規則對數據進行分類的過程。它提供一種在什么條件下會得到什么值的類似規則的方法。決策樹分為分類樹和回歸樹兩種,分類樹對離散變量做決策樹,回歸樹對連續變量做決策樹。

  近來的調查表明決策樹也是最經常使用的數據挖掘算法,它的概念非常簡單。決策樹算法之所以如此流行,一個很重要的原因就是使用者基本上不用了解機器學習算法,也不用深究它是如何工作的。直觀看上去,決策樹分類器就像判斷模塊和終止塊組成的流程圖,終止塊表示分類結果(也就是樹的葉子)。判斷模塊表示對一個特征取值的判斷(該特征有幾個值,判斷模塊就有幾個分支)。

  如果不考慮效率等,那么樣本所有特征的判斷級聯起來終會將某一個樣本分到一個類終止塊上。實際上,樣本所有特征中有一些特征在分類時起到決定性作用,決策樹的構造過程就是找到這些具有決定性作用的特征,根據其決定性程度來構造一個倒立的樹--決定性作用最大的那個特征作為根節點,然后遞歸找到各分支下子數據集中次大的決定性特征,直至子數據集中所有數據都屬於同一類。所以,構造決策樹的過程本質上就是根據數據特征將數據集分類的遞歸過程,我們需要解決的第一個問題就是,當前數據集上哪個特征在划分數據分類時起決定性作用。

2、決策樹的學習過程

一棵決策樹的生成過程主要分為以下3個部分:

  • 特征選擇:特征選擇是指從訓練數據中眾多的特征中選擇一個特征作為當前節點的分裂標准,如何選擇特征有着很多不同量化評估標准標准,從而衍生出不同的決策樹算法。

  • 決策樹生成: 根據選擇的特征評估標准,從上至下遞歸地生成子節點,直到數據集不可分則停止決策樹停止生長。 樹結構來說,遞歸結構是最容易理解的方式。

  • 剪枝:決策樹容易過擬合,一般來需要剪枝,縮小樹結構規模、緩解過擬合。剪枝技術有預剪枝和后剪枝兩種。

3、基於信息論的三種決策樹算法

  划分數據集的最大原則是:使無序的數據變的有序。如果一個訓練數據中有20個特征,那么選取哪個做划分依據?這就必須采用量化的方法來判斷,量化划分方法有多重,其中一項就是“信息論度量信息分類”。基於信息論的決策樹算法有ID3、CART和C4.5等算法,其中C4.5和CART兩種算法從ID3算法中衍生而來。

   CART和C4.5支持數據特征為連續分布時的處理,主要通過使用二元切分來處理連續型變量,即求一個特定的值-分裂值:特征值大於分裂值就走左子樹,或者就走右子樹。這個分裂值的選取的原則是使得划分后的子樹中的“混亂程度”降低,具體到C4.5和CART算法則有不同的定義方式。

  ID3算法由Ross Quinlan發明,建立在“奧卡姆剃刀”的基礎上:越是小型的決策樹越優於大的決策樹(be simple簡單理論)。ID3算法中根據信息論的信息增益評估和選擇特征,每次選擇信息增益最大的特征做判斷模塊。ID3算法可用於划分標稱型數據集,沒有剪枝的過程,為了去除過度數據匹配的問題,可通過裁剪合並相鄰的無法產生大量信息增益的葉子節點(例如設置信息增益閥值)。使用信息增益的話其實是有一個缺點,那就是它偏向於具有大量值的屬性--就是說在訓練集中,某個屬性所取的不同值的個數越多,那么越有可能拿它來作為分裂屬性,而這樣做有時候是沒有意義的,另外ID3不能處理連續分布的數據特征,於是就有了C4.5算法。CART算法也支持連續分布的數據特征。

  C4.5是ID3的一個改進算法,繼承了ID3算法的優點。C4.5算法用信息增益率來選擇屬性,克服了用信息增益選擇屬性時偏向選擇取值多的屬性的不足在樹構造過程中進行剪枝;能夠完成對連續屬性的離散化處理;能夠對不完整數據進行處理。C4.5算法產生的分類規則易於理解、准確率較高;但效率低,因樹構造過程中,需要對數據集進行多次的順序掃描和排序。也是因為必須多次數據集掃描,C4.5只適合於能夠駐留於內存的數據集。

  CART算法的全稱是Classification And Regression Tree,采用的是Gini指數(選Gini指數最小的特征s)作為分裂標准,同時它也是包含后剪枝操作。ID3算法和C4.5算法雖然在對訓練樣本集的學習中可以盡可能多地挖掘信息,但其生成的決策樹分支較大,規模較大。為了簡化決策樹的規模,提高生成決策樹的效率,就出現了根據GINI系數來選擇測試屬性的決策樹算法CART。

4、決策樹優缺點

  決策樹適用於數值型和標稱型(離散型數據,變量的結果只在有限目標集中取值),能夠讀取數據集合,提取一些列數據中蘊含的規則。在分類問題中使用決策樹模型有很多的優點,決策樹計算復雜度不高、便於使用、而且高效,決策樹可處理具有不相關特征的數據、可很容易地構造出易於理解的規則,而規則通常易於解釋和理解。決策樹模型也有一些缺點,比如處理缺失數據時的困難、過度擬合以及忽略數據集中屬性之間的相關性等。

(二)ID3算法的數學原理

  前面已經提到C4.5和CART都是由ID3演化而來,這里就先詳細闡述ID3算法,奠下基礎。

1、ID3算法的信息論基礎

  關於決策樹的信息論基礎可以參考“決策樹1-建模過程”

(1)信息熵

  信息熵:在概率論中,信息熵給了我們一種度量不確定性的方式,是用來衡量隨機變量不確定性的,熵就是信息的期望值。若待分類的事物可能划分在N類中,分別是x1,x2,……,xn,每一種取到的概率分別是P1,P2,……,Pn,那么X的熵就定義為:

,從定義中可知:0≤H(X)≤log(n)

  當隨機變量只取兩個值時,即X的分布為 P(X=1)=p,X(X=0)=1−p,0≤p≤1則熵為:H(X)=−plog2(p)−(1−p)log2(1−p)。

熵值越高,則數據混合的種類越高,其蘊含的含義是一個變量可能的變化越多(反而跟變量具體的取值沒有任何關系,只和值的種類多少以及發生概率有關),它攜帶的信息量就越大。熵在信息論中是一個非常重要的概念,很多機器學習的算法都會利用到這個概念。

(2)條件熵

  假設有隨機變量(X,Y),其聯合概率分布為:P(X=xi,Y=yi)=pij,i=1,2,⋯,n;j=1,2,⋯,m

  則條件熵(H(Y∣X))表示在已知隨機變量X的條件下隨機變量Y的不確定性,其定義為X在給定條件下Y的條件概率分布的熵對X的數學期望:

     
 

(3)信息增益

  信息增益(information gain)表示得知特征X的信息后,而使得Y的不確定性減少的程度。定義為:

     

2、ID3算法推導

(1)分類系統信息熵

  假設一個分類系統的樣本空間(D,Y),D表示樣本(有m個特征),Y表示n個類別,可能的取值是C1,C2,...,Cn。每一個類別出現的概率是P(C1),P(C2),...,P(Cn)。該分類系統的熵為:

  離散分布中,類別Ci出現的概率P(Ci),通過該類別出現的次數除去樣本總數即可得到。對於連續分布,常需要分塊做離散化處理獲得。

(2)條件熵

  根據條件熵的定義,分類系統中的條件熵指的是當樣本的某一特征X固定時的信息熵。由於該特征X可能的取值會有(x1,x2,……,xn),當計算條件熵而需要把它固定的時候,每一種可能都要固定一下,然后求統計期望。

  因此樣本特征X取值為xi的概率是Pi,該特征被固定為值xi時的條件信息熵就是H(C|X=xi),那么

  H(C|X)就是分類系統中特征X被固定時的條件熵(X=(x1,x2,……,xn)):

  若是樣本的該特征只有兩個值(x1 = 0,x2=1)對應(出現,不出現),如文本分類中某一個單詞的出現與否。那么對於特征二值的情況,我們用T代表特征,用t代表T出現,表示該特征出現。那么:

 

  與前面條件熵的公式對比一下,P(t)就是T出現的概率,就是T不出現的概率。結合信息熵的計算公式,可得:

     

  特征T出現的概率P(t),只要用出現過T的樣本數除以總樣本數就可以了;P(Ci|t)表示出現T的時候,類別Ci出現的概率,只要用出現了T並且屬於類別Ci的樣本數除以出現了T的樣本數就得到了。

(3)信息增益

  根據信息增益的公式,分類系統中特征X的信息增益就是:Gain(D, X) = H(C)-H(C|X)

  信息增益是針對一個一個的特征而言的,就是看一個特征X,系統有它和沒它的時候信息量各是多少,兩者的差值就是這個特征給系統帶來的信息增益。每次選取特征的過程都是通過計算每個特征值划分數據集后的信息增益,然后選取信息增益最高的特征。

  對於特征取值為二值的情況,特征T給系統帶來的信息增益就可以寫成系統原本的熵與固定特征T后的條件熵之差:

(4)經過上述一輪信息增益計算后會得到一個特征作為決策樹的根節點,該特征有幾個取值,根節點就會有幾個分支,每一個分支都會產生一個新的數據子集Dk,余下的遞歸過程就是對每個Dk再重復上述過程,直至子數據集都屬於同一類。

  在決策樹構造過程中可能會出現這種情況:所有特征都作為分裂特征用光了,但子集還不是純凈集(集合內的元素不屬於同一類別)。在這種情況下,由於沒有更多信息可以使用了,一般對這些子集進行“多數表決”,即使用此子集中出現次數最多的類別作為此節點類別,然后將此節點作為葉子節點。

 

(三)C4.5算法

1、信息增益比選擇最佳特征

  以信息增益進行分類決策時,存在偏向於取值較多的特征的問題。於是為了解決這個問題人們有開發了基於信息增益比的分類決策方法,也就是C4.5。C4.5與ID3都是利用貪心算法進行求解,不同的是分類決策的依據不同。

  因此,C4.5算法在結構與遞歸上與ID3完全相同,區別就在於選取決斷特征時選擇信息增益比最大的。

  信息增益比率度量是用ID3算法中的的增益度量Gain(D,X)和分裂信息度量SplitInformation(D,X)來共同定義的。分裂信息度量SplitInformation(D,X)就相當於特征X(取值為x1,x2,……,xn,各自的概率為P1,P2,...,Pn,Pk就是樣本空間中特征X取值為xk的數量除上該樣本空間總數)的熵。

  SplitInformation(D,X) = -P1 log2(P1)-P2 log2(P)-,...,-Pn log2(Pn)

  GainRatio(D,X) = Gain(D,X)/SplitInformation(D,X)

  在ID3中用信息增益選擇屬性時偏向於選擇分枝比較多的屬性值,即取值多的屬性,在C4.5中由於除以SplitInformation(D,X)=H(X),可以削弱這種作用。

2、處理連續數值型特征

  C4.5既可以處理離散型屬性,也可以處理連續性屬性。在選擇某節點上的分枝屬性時,對於離散型描述屬性,C4.5的處理方法與ID3相同。對於連續分布的特征,其處理方法是:

  先把連續屬性轉換為離散屬性再進行處理。雖然本質上屬性的取值是連續的,但對於有限的采樣數據它是離散的,如果有N條樣本,那么我們有N-1種離散化的方法:<=vj的分到左子樹,>vj的分到右子樹。計算這N-1種情況下最大的信息增益率。另外,對於連續屬性先進行排序(升序),只有在決策屬性(即分類發生了變化)發生改變的地方才需要切開,這可以顯著減少運算量。經證明,在決定連續特征的分界點時采用增益這個指標(因為若采用增益率,splittedinfo影響分裂點信息度量准確性,若某分界點恰好將連續特征分成數目相等的兩部分時其抑制作用最大),而選擇屬性的時候才使用增益率這個指標能選擇出最佳分類特征。

  在C4.5中,對連續屬性的處理如下:

    1、對特征的取值進行升序排序

    2、兩個特征取值之間的中點作為可能的分裂點,將數據集分成兩部分,計算每個可能的分裂點的信息增益(InforGain)。優化算法就是只計算分類屬性發生改變的那些特征取值。

    3、選擇修正后信息增益(InforGain)最大的分裂點作為該特征的最佳分裂點

    4、計算最佳分裂點的信息增益率(Gain Ratio)作為特征的Gain Ratio。注意,此處需對最佳分裂點的信息增益進行修正:減去log2(N-1)/|D|(N是連續特征的取值個數,D是訓練數據數目,此修正的原因在於:當離散屬性和連續屬性並存時,C4.5算法傾向於選擇連續特征做最佳樹分裂點)

3、葉子裁剪

  分析分類回歸樹的遞歸建樹過程,不難發現它實質上存在着一個數據過度擬合問題。在決策樹構造時,由於訓練數據中的噪音或孤立點,許多分枝反映的是訓練數據中的異常,使用這樣的判定樹對類別未知的數據進行分類,分類的准確性不高。因此試圖檢測和減去這樣的分支,檢測和減去這些分支的過程被稱為樹剪枝。樹剪枝方法用於處理過分適應數據問題。通常,這種方法使用統計度量,減去最不可靠的分支,這將導致較快的分類,提高樹獨立於訓練數據正確分類的能力。

  決策樹常用的剪枝常用的簡直方法有兩種:預剪枝(Pre-Pruning)和后剪枝(Post-Pruning)。預剪枝是根據一些原則及早的停止樹增長,如樹的深度達到用戶所要的深度、節點中樣本個數少於用戶指定個數、不純度指標下降的最大幅度小於用戶指定的幅度等。預剪枝的核心問題是如何事先指定樹的最大深度,如果設置的最大深度不恰當,那么將會導致過於限制樹的生長,使決策樹的表達式規則趨於一般,不能更好地對新數據集進行分類和預測。除了事先限定決策樹的最大深度之外,還有另外一個方法來實現預剪枝操作,那就是采用檢驗技術對當前結點對應的樣本集合進行檢驗,如果該樣本集合的樣本數量已小於事先指定的最小允許值,那么停止該結點的繼續生長,並將該結點變為葉子結點,否則可以繼續擴展該結點。

  后剪枝則是通過在完全生長的樹上剪去分枝實現的,通過刪除節點的分支來剪去樹節點,可以使用的后剪枝方法有多種,比如:代價復雜性剪枝、最小誤差剪枝、悲觀誤差剪枝等等。后剪枝操作是一個邊修剪邊檢驗的過程,一般規則標准是:在決策樹的不斷剪枝操作過程中,將原樣本集合或新數據集合作為測試數據,檢驗決策樹對測試數據的預測精度,並計算出相應的錯誤率,如果剪掉某個子樹后的決策樹對測試數據的預測精度或其他測度不降低,那么剪掉該子樹。

(四)決策樹python版實現

訓練集:

sunny   hot     high    false   N
sunny   hot     high    true    N
overcast        hot     high    false   Y
rain    mild    high    false   Y
rain    cool    normal  false   Y
rain    cool    normal  true    N
overcast        cool    normal  true    Y

測試集:

sunny   mild    high    false
sunny   cool    normal  false
rain    mild    normal  false
sunny   mild    normal  true
overcast        mild    high    true
overcast        hot     normal  false
rain    mild    high    true

code:

  1 #!/usr/bin/python
  2 import sys
  3 import copy
  4 import math
  5 import getopt
  6 
  7 def usage():
  8     print '''Help Information:
  9     -h, --help: show help information;
 10     -r, --train: train file;
 11     -t, --test: test file;
 12     '''
 13 
 14 def getparamenter():
 15     try:
 16       opts, args = getopt.getopt(sys.argv[1:], "hr:t:k:", ["help", "train=","test=","kst="])
 17     except getopt.GetoptError, err:
 18       print str(err)
 19       usage()
 20       sys.exit(1)
 21 
 22     sys.stderr.write("\ntrain.py : a python script for perception training.\n")
 23     sys.stderr.write("Copyright 2016 sxron, search, Sogou. \n")
 24     sys.stderr.write("Email: shixiang08abc@gmail.com \n\n")
 25 
 26     train = ''
 27     test = ''
 28     for i, f in opts:
 29       if i in ("-h", "--help"):
 30         usage()
 31         sys.exit(1)
 32       elif i in ("-r", "--train"):
 33         train = f
 34       elif i in ("-t", "--test"):
 35         test = f
 36       else:
 37         assert False, "unknown option"
 38   
 39     print "start trian parameter \ttrain:%s\ttest:%s" % (train,test)
 40 
 41     return train,test
 42 
 43 def loaddata(file):
 44     fin = open(file,'r')
 45     data = []
 46     while 1:
 47         dataline = []
 48         line = fin.readline()
 49         if not line:
 50             break
 51         ts = line.strip().split('\t')
 52         for temp in ts:
 53             dataline.append(temp.strip())
 54         data.append(dataline)
 55     return data
 56 
 57 def majorityCnt(classList):
 58     classCnt = {}
 59     for cls in classList:
 60         if not classCnt.has_key(cls):
 61             classCnt[cls] = 0
 62         classCnt[cls] += 1
 63 
 64     SortClassCnt = sorted(classCnt.iteritems(),key=lambda d:d[1],reverse=True)
 65     return SortClassCnt[0][0]
 66 
 67 def calcShannonEnt(trainData):
 68     numEntries = len(trainData)
 69     labelDic = {}
 70     for trainLine in trainData:
 71         currentLabel = trainLine[-1]
 72         if not labelDic.has_key(currentLabel):
 73             labelDic[currentLabel] = 0
 74         labelDic[currentLabel] += 1
 75 
 76     shannonEnt = 0.0
 77     for key,value in labelDic.items():
 78         prob = float(value)/numEntries
 79         shannonEnt -= prob * math.log(prob,2)
 80     return shannonEnt
 81 
 82 def splitData(trainData,index,value):
 83     subData = []
 84     for trainLine in trainData:
 85         if trainLine[index]==value:
 86             reducedFeatVec = []
 87             for i in range(0,len(trainLine),1):
 88                 if i==index:
 89                     continue
 90                 reducedFeatVec.append(trainLine[i])
 91             subData.append(reducedFeatVec)
 92     return subData
 93 
 94 def chooseBestFeature(trainData):
 95     numFeatures = len(trainData[0])-1
 96     baseEntropy = calcShannonEnt(trainData)
 97     bestInfoGain = 0.0
 98     bestFeature = -1
 99     for i in range(0,numFeatures,1):
100         currentFeature = [temp[i] for temp in trainData]
101         uniqueValues = set(currentFeature)
102         newEntropy = 0.0
103         splitInfo = 0.0
104         for value in uniqueValues:
105             subData = splitData(trainData,i,value)
106             prob = float(len(subData))/len(trainData)
107             newEntropy += prob * calcShannonEnt(subData)
108             splitInfo -= prob * math.log(prob,2)
109         infoGain = (baseEntropy - newEntropy) / splitInfo
110         if infoGain > bestInfoGain :
111             bestInfoGain = infoGain
112             bestFeature = i
113     return bestFeature
114 
115 def CreateTree(trainData):
116     classList = [temp[-1] for temp in trainData]
117     classListSet = set(classList)
118     if len(classListSet)==1:
119         return classList[0]
120     if len(trainData[0])==1:
121         return majorityCnt(classList)
122 
123     bestFeature = chooseBestFeature(trainData)
124     myTree = {bestFeature:{}}
125     featureValues = [example[bestFeature] for example in trainData]
126     uniqueValues = set(featureValues)
127     for value in uniqueValues:
128         myTree[bestFeature][value] = CreateTree(splitData(trainData, bestFeature, value))
129     return myTree
130 
131 def classify(testData,dTrees):
132     index = int(dTrees.keys()[0])
133     secondDict = dTrees[index]
134     testValue = testData[index]
135     for key in secondDict.keys():
136         if testValue==key:
137             if type(secondDict[key]).__name__=='dict':
138                 secondTest = copy.deepcopy(testData)
139                 del secondTest[index]
140                 classLabel = classify(secondTest,secondDict[key])
141             else:
142                 classLabel = secondDict[key]
143     return classLabel
144 
145 def TestFunc(testData,dTrees):
146     for temp in testData:
147         classLabel = classify(temp,dTrees)
148         print "%s\t%s" % (temp,classLabel)
149 
150 def main():
151     #set parameter
152     train,test = getparamenter()
153 
154     #load train data
155     trainData = loaddata(train)
156     testData = loaddata(test)
157 
158     #create Decision Tree
159     dTrees = CreateTree(trainData)
160     print dTrees
161 
162     #test Decision Tree
163     TestFunc(testData,dTrees) 
164 
165 if __name__=="__main__":
166     main()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM