決策樹之系列一ID3原理與代碼實現
本文系作者原創,轉載請注明出處:https://www.cnblogs.com/further-further-further/p/9429257.html
應用實例:
你是否玩過二十個問題的游戲,游戲的規則很簡單:參與游戲的一方在腦海里想某個事物,
其他參與者向他提問題,只允許提20個問題,問題的答案也只能用對或錯回答。問問題的人通過
推斷分解,逐步縮小待猜測事物的范圍。決策樹的工作原理與20個問題類似,用戶輸人一系列數
據,然后給出游戲的答案。如下表
假如我告訴你,我有一個海洋生物,它不浮出水面可以生存,並且沒有腳蹼,你來判斷一下是否屬於魚類?
通過決策樹,你就可以快速給出答案不是魚類。
決策樹的目的就是在一大堆無序的數據特征中找出有序的規則,並建立決策樹(模型)。
決策樹比較文縐縐的介紹
決策樹學習是一種逼近離散值目標函數的方法。通過將一組數據中學習的函數表示為決策樹,從而將大量數據有目的的分類,從而找到潛在有價值的信息。決策樹分類通常分為兩步---生成樹和剪枝;
樹的生成 --- 自上而下的遞歸分治法;
剪枝 --- 剪去那些可能增大錯誤預測率的分枝。
決策樹的方法起源於概念學習系統CLS(Concept Learning System), 然后發展最具有代表性的ID3(以信息熵作為目標評價函數)算法,最后又演化為C4.5, C5.0,CART可以處理連續屬性。
這篇文章主要介紹ID3算法原理與代碼實現(屬於分類算法)
分類與回歸的區別
回歸問題和分類問題的本質一樣,都是針對一個輸入做出一個輸出預測,其區別在於輸出變量的類型。
分類問題是指,給定一個新的模式,根據訓練集推斷它所對應的類別(如:+1,-1),是一種定性輸出,也叫離散變量預測;
回歸問題是指,給定一個新的模式,根據訓練集推斷它所對應的輸出值(實數)是多少,是一種定量輸出,也叫連續變量預測。
舉個例子:預測明天的氣溫是多少度,這是一個回歸任務;預測明天是陰、晴還是雨,就是一個分類任務。
分類模型可將回歸模型的輸出離散化,回歸模型也可將分類模型的輸出連續化。
信息論相關知識
來自王小猴<<機器學習實戰>>學習總結(二)------決策樹算法(https://zhuanlan.zhihu.com/p/29980400),他將原理說得很透徹形象,這里借鑒一下。
1. 信息熵
在決策樹算法中,熵是一個非常非常重要的概念。
一件事發生的概率越小,我們說它所蘊含的信息量越大。
比如:我們聽女人能懷孕不奇怪,如果某天聽到哪個男人懷孕了,那這個信息量就很大了......。
所以我們這樣衡量信息量:
其中,P(y)是事件發生的概率。
信息熵就是所有可能發生的事件的信息量的期望:
表達了Y事件發生的不確定度。
2. 條件熵
表示在X給定條件下,Y的條件概率分布的熵對X的數學期望。其數學推導如下:
舉個例子
例:女生決定主不主動追一個男生的標准有兩個:顏值和身高,如下表所示:

上表中隨機變量Y={追,不追},P(Y=追)=2/3,P(Y=不追)=1/3,得到Y的熵:

這里還有一個特征變量X,X={高,不高}。當X=高時,追的個數為1,占1/2,不追的個數為1,占1/2,此時:

同理:

(注意:我們一般約定,當p=0時,plogp=0)
所以我們得到條件熵的計算公式:


決策樹算法
1. 算法簡介
決策樹算法是一類常見的分類和回歸算法,顧名思義,決策樹是基於樹的結構來進行決策的。
以二分類為例,我們希望從給定訓練集中學得一個模型來對新的樣例進行分類。
以上面海洋生物為例
no surfacing:不浮出水面是否可以生存
flippers:是否有腳蹼
將表特征量化(是:1,否:0)
我們可以建立這樣一顆決策樹(后面結果證明,這是最佳的決策樹):
代碼實現
paython3.6,Spyder運行環境,每行代碼我基本都做了注釋,最終能生成最優決策樹結構,並用pyplot繪制了決策樹,以及該決策樹的葉子結點,樹的深度。
ID3算法的核心是在決策樹的各個結點上應用信息增益准則進行特征選擇。具體做法是:
- 從根節點開始,對結點計算所有可能特征的信息增益,選擇信息增益最大的特征作為結點的特征,並由該特征的不同取值構建子節點;
- 對子節點遞歸地調用以上方法,構建決策樹;
- 直到所有特征的信息增益均很小或者沒有特征可選時為止。
myTrees.py文件:

1 # -*- coding: utf-8 -*- 2 """ 3 Created on Thu Aug 2 17:09:34 2018 4 決策樹ID3的實現 5 @author: weixw 6 """ 7 from math import log 8 import operator 9 #原始數據 10 def createDataSet(): 11 dataSet = [[1, 1, 'yes'], 12 [1, 1, 'yes'], 13 [1, 0, 'no'], 14 [0, 1, 'no'], 15 [0, 1, 'no']] 16 labels = ['no surfacing','flippers'] 17 return dataSet, labels 18 19 #多數表決器 20 #列中相同值數量最多為結果 21 def majorityCnt(classList): 22 classCounts = {} 23 for value in classList: 24 if(value not in classCounts.keys()): 25 classCounts[value] = 0 26 classCounts[value] +=1 27 sortedClassCount = sorted(classCounts.iteritems(),key = operator.itemgetter(1),reverse =True) 28 return sortedClassCount[0][0] 29 30 31 #划分數據集 32 #dataSet:原始數據集 33 #axis:進行分割的指定列索引 34 #value:指定列中的值 35 def splitDataSet(dataSet,axis,value): 36 retDataSet= [] 37 for featDataVal in dataSet: 38 if featDataVal[axis] == value: 39 #下面兩行去除某一項指定列的值,很巧妙有沒有 40 reducedFeatVal = featDataVal[:axis] 41 reducedFeatVal.extend(featDataVal[axis+1:]) 42 retDataSet.append(reducedFeatVal) 43 return retDataSet 44 45 #計算香農熵 46 def calcShannonEnt(dataSet): 47 #數據集總項數 48 numEntries = len(dataSet) 49 #標簽計數對象初始化 50 labelCounts = {} 51 for featDataVal in dataSet: 52 #獲取數據集每一項的最后一列的標簽值 53 currentLabel = featDataVal[-1] 54 #如果當前標簽不在標簽存儲對象里,則初始化,然后計數 55 if currentLabel not in labelCounts.keys(): 56 labelCounts[currentLabel] = 0 57 labelCounts[currentLabel] += 1 58 #熵初始化 59 shannonEnt = 0.0 60 #遍歷標簽對象,求概率,計算熵 61 for key in labelCounts.keys(): 62 prop = labelCounts[key]/float(numEntries) 63 shannonEnt -= prop*log(prop,2) 64 return shannonEnt 65 66 #選出最優特征列索引 67 def chooseBestFeatureToSplit(dataSet): 68 #計算特征個數,dataSet最后一列是標簽屬性,不是特征量 69 numFeatures = len(dataSet[0])-1 70 #計算初始數據香農熵 71 baseEntropy = calcShannonEnt(dataSet) 72 #初始化信息增益,最優划分特征列索引 73 bestInfoGain = 0.0 74 bestFeatureIndex = -1 75 for i in range(numFeatures): 76 #獲取每一列數據 77 featList = [example[i] for example in dataSet] 78 #將每一列數據去重 79 uniqueVals = set(featList) 80 newEntropy = 0.0 81 for value in uniqueVals: 82 subDataSet = splitDataSet(dataSet,i,value) 83 #計算條件概率 84 prob = len(subDataSet)/float(len(dataSet)) 85 #計算條件熵 86 newEntropy +=prob*calcShannonEnt(subDataSet) 87 #計算信息增益 88 infoGain = baseEntropy - newEntropy 89 if(infoGain > bestInfoGain): 90 bestInfoGain = infoGain 91 bestFeatureIndex = i 92 return bestFeatureIndex 93 94 #決策樹創建 95 def createTree(dataSet,labels): 96 #獲取標簽屬性,dataSet最后一列,區別於labels標簽名稱 97 classList = [example[-1] for example in dataSet] 98 #樹極端終止條件判斷 99 #標簽屬性值全部相同,返回標簽屬性第一項值 100 if classList.count(classList[0]) == len(classList): 101 return classList[0] 102 #只有一個特征(1列) 103 if len(dataSet[0]) == 1: 104 return majorityCnt(classList) 105 #獲取最優特征列索引 106 bestFeatureIndex = chooseBestFeatureToSplit(dataSet) 107 #獲取最優索引對應的標簽名稱 108 bestFeatureLabel = labels[bestFeatureIndex] 109 #創建根節點 110 myTree = {bestFeatureLabel:{}} 111 #去除最優索引對應的標簽名,使labels標簽能正確遍歷 112 del(labels[bestFeatureIndex]) 113 #獲取最優列 114 bestFeature = [example[bestFeatureIndex] for example in dataSet] 115 uniquesVals = set(bestFeature) 116 for value in uniquesVals: 117 #子標簽名稱集合 118 subLabels = labels[:] 119 #遞歸 120 myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet,bestFeatureIndex,value),subLabels) 121 return myTree 122 123 #獲取分類結果 124 #inputTree:決策樹字典 125 #featLabels:標簽列表 126 #testVec:測試向量 例如:簡單實例下某一路徑 [1,1] => yes(樹干值組合,從根結點到葉子節點) 127 def classify(inputTree,featLabels,testVec): 128 #獲取根結點名稱,將dict轉化為list 129 firstSide = list(inputTree.keys()) 130 #根結點名稱String類型 131 firstStr = firstSide[0] 132 #獲取根結點對應的子節點 133 secondDict = inputTree[firstStr] 134 #獲取根結點名稱在標簽列表中對應的索引 135 featIndex = featLabels.index(firstStr) 136 #由索引獲取向量表中的對應值 137 key = testVec[featIndex] 138 #獲取樹干向量后的對象 139 valueOfFeat = secondDict[key] 140 #判斷是子結點還是葉子節點:子結點就回調分類函數,葉子結點就是分類結果 141 #if type(valueOfFeat).__name__=='dict': 等價 if isinstance(valueOfFeat, dict): 142 if isinstance(valueOfFeat, dict): 143 classLabel = classify(valueOfFeat,featLabels,testVec) 144 else: 145 classLabel = valueOfFeat 146 return classLabel 147 148 149 #將決策樹分類器存儲在磁盤中,filename一般保存為txt格式 150 def storeTree(inputTree,filename): 151 import pickle 152 fw = open(filename,'wb+') 153 pickle.dump(inputTree,fw) 154 fw.close() 155 #將瓷盤中的對象加載出來,這里的filename就是上面函數中的txt文件 156 def grabTree(filename): 157 import pickle 158 fr = open(filename,'rb') 159 return pickle.load(fr) 160 161
treePlotter.py文件:

1 ''' 2 Created on Oct 14, 2010 3 4 @author: Peter Harrington 5 ''' 6 import matplotlib.pyplot as plt 7 8 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 9 leafNode = dict(boxstyle="round4", fc="0.8") 10 arrow_args = dict(arrowstyle="<-") 11 12 #獲取樹的葉子節點 13 def getNumLeafs(myTree): 14 numLeafs = 0 15 #dict轉化為list 16 firstSides = list(myTree.keys()) 17 firstStr = firstSides[0] 18 secondDict = myTree[firstStr] 19 for key in secondDict.keys(): 20 #判斷是否是葉子節點(通過類型判斷,子類不存在,則類型為str;子類存在,則為dict) 21 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes 22 numLeafs += getNumLeafs(secondDict[key]) 23 else: numLeafs +=1 24 return numLeafs 25 26 #獲取樹的層數 27 def getTreeDepth(myTree): 28 maxDepth = 0 29 #dict轉化為list 30 firstSides = list(myTree.keys()) 31 firstStr = firstSides[0] 32 secondDict = myTree[firstStr] 33 for key in secondDict.keys(): 34 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes 35 thisDepth = 1 + getTreeDepth(secondDict[key]) 36 else: thisDepth = 1 37 if thisDepth > maxDepth: maxDepth = thisDepth 38 return maxDepth 39 40 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 41 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', 42 xytext=centerPt, textcoords='axes fraction', 43 va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 44 45 def plotMidText(cntrPt, parentPt, txtString): 46 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] 47 yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 48 createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) 49 50 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on 51 numLeafs = getNumLeafs(myTree) #this determines the x width of this tree 52 depth = getTreeDepth(myTree) 53 firstSides = list(myTree.keys()) 54 firstStr = firstSides[0] #the text label for this node should be this 55 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) 56 plotMidText(cntrPt, parentPt, nodeTxt) 57 plotNode(firstStr, cntrPt, parentPt, decisionNode) 58 secondDict = myTree[firstStr] 59 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD 60 for key in secondDict.keys(): 61 if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes 62 plotTree(secondDict[key],cntrPt,str(key)) #recursion 63 else: #it's a leaf node print the leaf node 64 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW 65 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 66 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 67 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD 68 #if you do get a dictonary you know it's a tree, and the first element will be another dict 69 #繪制決策樹 70 def createPlot(inTree): 71 fig = plt.figure(1, facecolor='white') 72 fig.clf() 73 axprops = dict(xticks=[], yticks=[]) 74 createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks 75 #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 76 plotTree.totalW = float(getNumLeafs(inTree)) 77 plotTree.totalD = float(getTreeDepth(inTree)) 78 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; 79 plotTree(inTree, (0.5,1.0), '') 80 plt.show() 81 82 #繪制樹的根節點和葉子節點(根節點形狀:長方形,葉子節點:橢圓形) 83 #def createPlot(): 84 # fig = plt.figure(1, facecolor='white') 85 # fig.clf() 86 # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 87 # plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode) 88 # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode) 89 # plt.show() 90 91 def retrieveTree(i): 92 listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, 93 {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}} 94 ] 95 return listOfTrees[i] 96 97 #thisTree = retrieveTree(0) 98 #createPlot(thisTree) 99 #createPlot() 100 #myTree = retrieveTree(0) 101 #numLeafs =getNumLeafs(myTree) 102 #treeDepth =getTreeDepth(myTree) 103 #print(u"葉子節點數目:%d"% numLeafs) 104 #print(u"樹深度:%d"%treeDepth)
testTrees_3.py測試文件:

1 # -*- coding: utf-8 -*- 2 """ 3 Created on Fri Aug 3 19:52:10 2018 4 5 @author: weixw 6 """ 7 import myTrees as mt 8 import treePlotter as tp 9 #測試 10 dataSet, labels = mt.createDataSet() 11 #copy函數:新開辟一塊內存,然后將list的所有值復制到新開辟的內存中 12 labels1 = labels.copy() 13 #createTree函數中將labels1的值改變了,所以在分類測試時不能用labels1 14 myTree = mt.createTree(dataSet,labels1) 15 #保存樹到本地 16 mt.storeTree(myTree,'myTree.txt') 17 #在本地磁盤獲取樹 18 myTree = mt.grabTree('myTree.txt') 19 print (u"決策樹結構:%s"%myTree) 20 #繪制決策樹 21 print(u"繪制決策樹:") 22 tp.createPlot(myTree) 23 numLeafs =tp.getNumLeafs(myTree) 24 treeDepth =tp.getTreeDepth(myTree) 25 print(u"葉子節點數目:%d"% numLeafs) 26 print(u"樹深度:%d"%treeDepth) 27 #測試分類 簡單樣本數據3列 28 labelResult =mt.classify(myTree,labels,[1,1]) 29 print(u"[1,1] 測試結果為:%s"%labelResult) 30 labelResult =mt.classify(myTree,labels,[1,0]) 31 print(u"[1,0] 測試結果為:%s"%labelResult)
運行結果:
不要讓懶惰占據你的大腦,不要讓妥協拖垮你的人生。青春就是一張票,能不能趕上時代的快車,你的步伐掌握在你的腳下。