閑來無事最近復習了一下ID3決策樹算法,並憑着理解用pandas實現了一遍。對pandas更熟悉的朋友可供參考(鏈接如下)。相比本篇博文,更簡明清晰,更適合復習用。
https://github.com/DianeSoHungry/ShallowMachineLearningCodeItOut/blob/master/ID3.ipynb
現在要介紹的是ID3決策樹算法,只適用於標稱型數據,不適用於數值型數據。
決策樹學習算法最大的優點是,他可以自學習,在學習過程中,不需要使用者了解過多的背景知識、領域知識,只需要對訓練實例進行較好的標注就可以自學習了。
建立決策樹的關鍵在於當前狀態下選擇哪一個屬性作為分類依據,根據不同的目標函數,有三種主要的算法:
ID3(Iterative Dichotomiser)
C4.5
CART(Classification And Regression Tree)
問題描述:
下面是一個小型的數據集,5條記錄,2個特征(屬性),有標簽。
根據這個數據集,我們可以建立如下決策樹(用matplotlib的注釋功能畫的)。
觀察決策樹,決策節點為特征,其分支為決策節點的各個不同取值,葉節點為預測值。
建樹結束也就是建立好了一個決策樹分類器,有了分類器,就可以根據這個分類器對其他的魚進行預測了。預測准確性今天暫且不討論。
那么如何建立這樣的決策樹呢?
第一步:建立決策樹。
1.1 利用信息增益尋找當前最佳分類特征
想象現在你是一個判斷結點,你從頭頂的分支上獲得了一個數據集,表中包含標簽和若干屬性。你現在要根據某個屬性來對你接收到的數據集進行分組。到底用哪個屬性來作為划分依據呢?
我們用信息增益來選擇某個節點上用哪個特征來進行分類。
什么是信息?
如果待分類的事物可能划分在多個分類中,則每個分類xi的信息定義為:
(這里log前面應該有個負號。)
什么是香農熵?
香農熵是所有類別所有可能類別信息的期望值,即:
什么是信息增益?
信息增益=原香農熵-新香農熵
注意:新香農熵為按照某特征划分之后,每個分支數據集的香農熵之和。
可以這樣想:香農熵相當於數據類別(標簽)的混亂程度,信息增益可以衡量划分數據集前后數據(標簽)向有序性發展的程度。因此,回到怎樣利用信息增益尋找當前最佳分類特征的話題,假如你是一個判斷節點,你拿來一個數據集,數據集里面有若干個特征,你需要從中選取一個特征,使得信息增益最大(注意:將數據集中在該特征上取值相同的記錄划分到同一個分支,得到若干個分支數據集,每個分支數據集都有自己的香農熵,各個分支數據集的香農熵的期望才是新香農熵)。要找到這個特征只需要將數據集中的每個特征遍歷一次,求信息增益,取獲得最大信息增益的那個特征。
代碼如下(其中,calcShannonEnt(dataSet)函數用來計算數據集dataSet的香農熵,splitDataSet(dataSet, axis, value)函數將數據集dataSet的第axis列中特征值為value的記錄挑出來,組成分支數據集返回給函數。這兩個函數后面會給出函數定義。):
1 # 3-3 選擇最好的'數據集划分方式'(特征) 2 # 一個一個地試每個特征,如果某個按照某個特征分類得到的信息增益(原香農熵-新香農熵)最大, 3 # 則選這個特征作為最佳數據集划分方式 4 def chooseBestFeatureToSplit(dataSet): 5 numFeatures = len(dataSet[0]) - 1 6 baseEntropy = calcShannonEnt(dataSet) 7 bestInfoGain = 0.0 8 bestFeature = -1 9 for i in range(numFeatures): 10 featList = [example[i] for example in dataSet] 11 uniqueVals = set(featList) 12 newEntropy = 0.0 13 for value in uniqueVals: 14 subDataSet = splitDataSet(dataSet, i, value) 15 prob = len(subDataSet) / float(len(dataSet)) 16 newEntropy += prob * calcShannonEnt(subDataSet) 17 infoGain = baseEntropy - newEntropy 18 if (infoGain > bestInfoGain): 19 bestInfoGain = infoGain 20 bestFeature = i 21 return bestFeature
calcShannonEnt(dataSet)函數代碼:
1 def calcShannonEnt(dataSet): 2 numEntries = len(dataSet) # 總記錄數 3 labelCounts = {} # dataSet中所有出現過的標簽值為鍵,相應標簽值出現過的次數作為值 4 for featVec in dataSet: 5 currentLabel = featVec[-1] 6 labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1 7 shannonEnt = 0.0 8 for key in labelCounts: 9 prob = -float(labelCounts[key])/numEntries 10 shannonEnt += prob * log(prob, 2) 11 return shannonEnt
splitDataSet(dataSet, axis, value)函數代碼:
1 # 3-2 按照給定特征划分數據集(在某個特征axis上,值等於value的所有記錄 2 # 組成新的數據集retDataSet,新數據集不需要axis這個特征,注意value是特征值,axis指的是特征(所在的列下標)) 3 def splitDataSet(dataSet, axis, value): 4 retDataSet = [] 5 for featVec in dataSet: 6 if featVec[axis] == value: 7 reducedFeatVec = featVec[:axis] 8 reducedFeatVec.extend(featVec[axis+1:]) 9 retDataSet.append(reducedFeatVec) 10 return retDataSet
1.2 建樹
建樹是一個遞歸的過程。
遞歸結束的標志(判斷某節點是葉節點的標志):
情況1. 分到該節點的數據集中,所有記錄的標簽列取值都一樣。
或
情況2. 分到該節點的數據集中,只剩下標簽列。
a. 經判斷,若是葉節點,則:
對應情況1,返回數據集中第一條記錄的標簽值(反正所有標簽值都一樣)。
對應情況2,返回數據集中所有標簽值中,出現次數最多的那個標簽值(代碼中,定義一個函數majorityCnt(classList)來實現)
b. 經判斷,若不是葉節點,則:
step1. 建立一個字典,字典的鍵為該數據集上選出的最佳特征(划分依據)。
step2. 將具有相同特征值的記錄組成新的數據集(利用splitDataSet(dataSet, axis, value)函數實現,注意期間拋棄了當前用於划分數據的特征列),對新的數據集們進行遞歸建樹。
建樹代碼:
1 # 3-4 創建樹的函數代碼 2 # 如果非葉子結點,則以當前數據集建樹,並返回該樹。該樹的根節點是一個字典,鍵為划分當前數據集的最佳特征,值為按照鍵值划分后各個數據集構造的樹 3 # 葉子節點有兩種:1.只剩沒有特征時,葉子節點的返回值為所有記錄中,出現次數最多的那個標簽值 2.該葉子節點中,所有記錄的標簽相同。 4 5 def createTree(dataSet, labels): #label向量的維度為特征數,不是記錄數,是不同列下標對應的特征 6 classList = [example[-1] for example in dataSet] 7 if classList.count(classList[0]) == len(classList): 8 return classList[0] 9 if len(dataSet[0]) == 1: 10 return majorityCnt(classList) 11 bestFeat = chooseBestFeatureToSplit(dataSet) 12 bestFeatLabel = labels[bestFeat] 13 myTree = {bestFeatLabel: {}} 14 del(labels[bestFeat]) 15 featValues = [example[bestFeat] for example in dataSet] 16 uniqueVals = set(featValues) 17 for value in uniqueVals: #遞歸建子樹,若值為字典,則非葉節點,若為字符串,則為葉節點 18 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels) 19 return myTree
用上面給出的數據來建立一顆決策樹做示范:
在同一個程序中輸入如下代碼並運行:
1 def createDataSet(): 2 dataSet = [[1, 1, 'yes'], 3 [1, 1, 'yes'], 4 [1, 0, 'no'], 5 [0, 1, 'no'], 6 [0, 1, 'no']] 7 labels = ['no surfacing', 'flippers'] 8 return dataSet, labels 9 10 myDat, labels = createDataSet() 11 myTree = createTree(myDat, labels) 12 print myTree
運行結果為:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
若利用后面畫決策樹的代碼可以畫出這顆決策樹:
案例:
我們通過建立決策樹來預測患者需要佩戴哪種隱形眼鏡(soft(軟材質)、hard(硬材質)、no lenses(不適合硬性眼睛)),數據集包含下面幾個特征:age(年齡), prescript(近視還是遠視), astigmatic(散光), tearRate(眼淚清除率)
建樹的結果為:
{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
畫出來是這個樣子:
畫決策樹的代碼(不講)
涉及matplotlib.pyplot模塊中的annotation的用法,點擊鏈接進入官網學習這塊內容的prerequisite。
1 # _*_coding:utf-8_*_ 2 3 # 3-7 plotTree函數 4 import matplotlib.pyplot as plt 5 6 # 定義節點和箭頭格式的常量 7 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 8 leafNode = dict(boxstyle="round4", fc="0.8") 9 arrow_args = dict(arrowstyle="<-") 10 11 12 def plotMidTest(cntrPt, parentPt,txtString): 13 xMid = (parentPt[0] + cntrPt[0])/2.0 14 yMid = (parentPt[1] + cntrPt[1])/2.0 15 createPlot.ax1.text(xMid, yMid, txtString) 16 17 # 繪制自身 18 # 若當前子節點不是葉子節點,遞歸 19 # 若當子節點為葉子節點,繪制該節點 20 def plotTree(myTree, parentPt, nodeTxt): 21 numLeafs = getNumLeafs(myTree) 22 # depth = getTreeDepth(myTree) 23 firstStr = myTree.keys()[0] 24 cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yoff) 25 plotMidTest(cntrPt, parentPt, nodeTxt) 26 plotNode(firstStr, cntrPt, parentPt, decisionNode) 27 secondDict = myTree[firstStr] 28 plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD 29 for key in secondDict.keys(): 30 if type(secondDict[key]).__name__=='dict': 31 plotTree(secondDict[key], cntrPt, str(key)) 32 else: 33 plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW 34 plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode) 35 plotMidTest((plotTree.xoff, plotTree.yoff), cntrPt, str(key)) 36 plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD 37 38 39 # figure points 40 # 畫結點的模板 41 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 42 createPlot.ax1.annotate(nodeTxt, # 注釋的文字,(一個字符串) 43 xy=parentPt, # 被注釋的地方(一個坐標) 44 xycoords='axes fraction', # xy所用的坐標系 45 xytext=centerPt, # 插入文本的地方(一個坐標) 46 textcoords='axes fraction', # xytext所用的坐標系 47 va="center", 48 ha="center", 49 bbox=nodeType, # 注釋文字用的框的格式 50 arrowprops=arrow_args) # 箭頭屬性 51 52 53 def createPlot(inTree): 54 fig = plt.figure(1, facecolor='white') 55 fig.clf() 56 axprops = dict(xticks=[], yticks=[]) 57 createPlot.ax1 = plt.subplot(111,frameon=False, **axprops) 58 plotTree.totalW = float(getNumLeafs(inTree)) 59 plotTree.totalD = float(getTreeDepth(inTree)) 60 plotTree.xoff = -0.5/plotTree.totalW 61 plotTree.yoff = 1.0 62 63 plotTree(inTree, (0.5, 1.0),'') #樹的引用作為父節點,但不畫出來,所以用'' 64 plt.show() 65 66 def getNumLeafs(myTree): 67 numLeafs = 0 68 firstStr = myTree.keys()[0] 69 secondDict = myTree[firstStr] 70 for key in secondDict.keys(): 71 if type(secondDict[key]).__name__ =='dict': 72 numLeafs += getNumLeafs(secondDict[key]) 73 else: 74 numLeafs += 1 75 return numLeafs 76 77 # 子樹中樹高最大的那一顆的高度+1作為當前數的高度 78 def getTreeDepth(myTree): 79 maxDepth = 0 #用來記錄最高子樹的高度+1 80 firstStr = myTree.keys()[0] 81 secondDict = myTree[firstStr] 82 for key in secondDict.keys(): 83 if type(secondDict[key]).__name__ == 'dict': 84 thisDepth = 1 + getTreeDepth(secondDict[key]) 85 else: 86 thisDepth = 1 87 if(thisDepth > maxDepth): 88 maxDepth = thisDepth 89 return maxDepth 90 91 # 方便測試用的人造測試樹 92 def retrieveTree(i): 93 listofTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}}, 94 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}} 95 ] 96 return listofTrees[i]