決策樹原理實例(python代碼實現)
決策數(Decision Tree)在機器學習中也是比較常見的一種算法,屬於監督學習中的一種。看字面意思應該也比較容易理解,相比其他算法比如支持向量機(SVM)或神經網絡,似乎決策樹感覺“親切”許多。
- 優點:計算復雜度不高,輸出結果易於理解,對中間值的缺失值不敏感,可以處理不相關特征數據。
- 缺點:可能會產生過度匹配的問題。
- 使用數據類型:數值型和標稱型。
簡單介紹完畢,讓我們來通過一個例子讓決策樹“原形畢露”。
一天,老師問了個問題,只根據頭發和聲音怎么判斷一位同學的性別。
為了解決這個問題,同學們馬上簡單的統計了7位同學的相關特征,數據如下:
頭發 | 聲音 | 性別 |
---|---|---|
長 | 粗 | 男 |
短 | 粗 | 男 |
短 | 粗 | 男 |
長 | 細 | 女 |
短 | 細 | 女 |
短 | 粗 | 女 |
長 | 粗 | 女 |
長 | 粗 | 女 |
機智的同學A想了想,先根據頭發判斷,若判斷不出,再根據聲音判斷,於是畫了一幅圖,如下:
於是,一個簡單、直觀的決策樹就這么出來了。頭發長、聲音粗就是男生;頭發長、聲音細就是女生;頭發短、聲音粗是男生;頭發短、聲音細是女生。
原來機器學習中決策樹就這玩意,這也太簡單了吧。。。
這時又蹦出個同學B,想先根據聲音判斷,然后再根據頭發來判斷,如是大手一揮也畫了個決策樹:
同學B的決策樹:首先判斷聲音,聲音細,就是女生;聲音粗、頭發長是男生;聲音粗、頭發長是女生。
那么問題來了:同學A和同學B誰的決策樹好些?計算機做決策樹的時候,面對多個特征,該如何選哪個特征為最佳的划分特征?
划分數據集的大原則是:將無序的數據變得更加有序。
我們可以使用多種方法划分數據集,但是每種方法都有各自的優缺點。於是我們這么想,如果我們能測量數據的復雜度,對比按不同特征分類后的數據復雜度,若按某一特征分類后復雜度減少的更多,那么這個特征即為最佳分類特征。
Claude Shannon 定義了熵(entropy)和信息增益(information gain)。
用熵來表示信息的復雜度,熵越大,則信息越復雜。公式如下:
信息增益(information gain),表示兩個信息熵的差值。
首先計算未分類前的熵,總共有8位同學,男生3位,女生5位。
熵(總)=-3/8*log2(3/8)-5/8*log2(5/8)=0.9544
接着分別計算同學A和同學B分類后信息熵。
同學A首先按頭發分類,分類后的結果為:長頭發中有1男3女。短頭發中有2男2女。
熵(同學A長發)=-1/4*log2(1/4)-3/4*log2(3/4)=0.8113
熵(同學A短發)=-2/4*log2(2/4)-2/4*log2(2/4)=1
熵(同學A)=4/8*0.8113+4/8*1=0.9057
信息增益(同學A)=熵(總)-熵(同學A)=0.9544-0.9057=0.0487
同理,按同學B的方法,首先按聲音特征來分,分類后的結果為:聲音粗中有3男3女。聲音細中有0男2女。
熵(同學B聲音粗)=-3/6*log2(3/6)-3/6*log2(3/6)=1
熵(同學B聲音粗)=-2/2*log2(2/2)=0
熵(同學B)=6/8*1+2/8*0=0.75
信息增益(同學B)=熵(總)-熵(同學A)=0.9544-0.75=0.2087
按同學B的方法,先按聲音特征分類,信息增益更大,區分樣本的能力更強,更具有代表性。
以上就是決策樹ID3算法的核心思想。
接下來用python代碼來實現ID3算法:
1 from math import log 2 import operator 3 4 def calcShannonEnt(dataSet): # 計算數據的熵(entropy) 5 numEntries=len(dataSet) # 數據條數 6 labelCounts={} 7 for featVec in dataSet: 8 currentLabel=featVec[-1] # 每行數據的最后一個字(類別) 9 if currentLabel not in labelCounts.keys(): 10 labelCounts[currentLabel]=0 11 labelCounts[currentLabel]+=1 # 統計有多少個類以及每個類的數量 12 shannonEnt=0 13 for key in labelCounts: 14 prob=float(labelCounts[key])/numEntries # 計算單個類的熵值 15 shannonEnt-=prob*log(prob,2) # 累加每個類的熵值 16 return shannonEnt 17 18 def createDataSet1(): # 創造示例數據 19 dataSet = [['長', '粗', '男'], 20 ['短', '粗', '男'], 21 ['短', '粗', '男'], 22 ['長', '細', '女'], 23 ['短', '細', '女'], 24 ['短', '粗', '女'], 25 ['長', '粗', '女'], 26 ['長', '粗', '女']] 27 labels = ['頭發','聲音'] #兩個特征 28 return dataSet,labels 29 30 def splitDataSet(dataSet,axis,value): # 按某個特征分類后的數據 31 retDataSet=[] 32 for featVec in dataSet: 33 if featVec[axis]==value: 34 reducedFeatVec =featVec[:axis] 35 reducedFeatVec.extend(featVec[axis+1:]) 36 retDataSet.append(reducedFeatVec) 37 return retDataSet 38 39 def chooseBestFeatureToSplit(dataSet): # 選擇最優的分類特征 40 numFeatures = len(dataSet[0])-1 41 baseEntropy = calcShannonEnt(dataSet) # 原始的熵 42 bestInfoGain = 0 43 bestFeature = -1 44 for i in range(numFeatures): 45 featList = [example[i] for example in dataSet] 46 uniqueVals = set(featList) 47 newEntropy = 0 48 for value in uniqueVals: 49 subDataSet = splitDataSet(dataSet,i,value) 50 prob =len(subDataSet)/float(len(dataSet)) 51 newEntropy +=prob*calcShannonEnt(subDataSet) # 按特征分類后的熵 52 infoGain = baseEntropy - newEntropy # 原始熵與按特征分類后的熵的差值 53 if (infoGain>bestInfoGain): # 若按某特征划分后,熵值減少的最大,則次特征為最優分類特征 54 bestInfoGain=infoGain 55 bestFeature = i 56 return bestFeature 57 58 def majorityCnt(classList): #按分類后類別數量排序,比如:最后分類為2男1女,則判定為男; 59 classCount={} 60 for vote in classList: 61 if vote not in classCount.keys(): 62 classCount[vote]=0 63 classCount[vote]+=1 64 sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) 65 return sortedClassCount[0][0] 66 67 def createTree(dataSet,labels): 68 classList=[example[-1] for example in dataSet] # 類別:男或女 69 if classList.count(classList[0])==len(classList): 70 return classList[0] 71 if len(dataSet[0])==1: 72 return majorityCnt(classList) 73 bestFeat=chooseBestFeatureToSplit(dataSet) #選擇最優特征 74 bestFeatLabel=labels[bestFeat] 75 myTree={bestFeatLabel:{}} #分類結果以字典形式保存 76 del(labels[bestFeat]) 77 featValues=[example[bestFeat] for example in dataSet] 78 uniqueVals=set(featValues) 79 for value in uniqueVals: 80 subLabels=labels[:] 81 myTree[bestFeatLabel][value]=createTree(splitDataSet\ 82 (dataSet,bestFeat,value),subLabels) 83 return myTree 84 85 86 if __name__=='__main__': 87 dataSet, labels=createDataSet1() # 創造示列數據 88 print(createTree(dataSet, labels)) # 輸出決策樹模型結果
輸出結果為:
1 {'聲音': {'細': '女', '粗': {'頭發': {'短': '男', '長': '女'}}}}
這個結果的意思是:首先按聲音分類,聲音細為女生;然后再按頭發分類:聲音粗,頭發短為男生;聲音粗,頭發長為女生。
這個結果也正是同學B的結果。
補充說明:判定分類結束的依據是,若按某特征分類后出現了最終類(男或女),則判定分類結束。使用這種方法,在數據比較大,特征比較多的情況下,很容易造成過擬合,於是需進行決策樹枝剪,一般枝剪方法是當按某一特征分類后的熵小於設定值時,停止分類。
ID3算法存在的缺點:
1. ID3算法在選擇根節點和內部節點中的分支屬性時,采用信息增益作為評價標准。信息增益的缺點是傾向於選擇取值較多是屬性,在有些情況下這類屬性可能不會提供太多有價值的信息。
2. ID3算法只能對描述屬性為離散型屬性的數據集構造決策樹 。
為了改進決策樹,又提出了ID4.5算法和CART算法。之后有時間會介紹這兩種算法。
參考:
- Machine Learning in Action
- 統計學習方法
轉載:http://blog.csdn.net/csqazwsxedc/article/details/65697652