前言
本文詳細介紹機器學習分類算法中的決策樹算法,並全面詳解如何構造,表示,保存決策樹,以及如何使用決策樹進行分類等等問題。
為了全面的理解學習決策樹,本文篇幅較長,請耐心閱讀。
算法原理
每次依據不同的特征信息對數據集進行划分,划分的最終結果是一棵樹。
該樹的每個子樹存放一個划分集,而每個葉節點則表示最終分類結果,這樣一棵樹被稱為決策樹。
決策樹建好之后,帶着目標對象按照一定規則遍歷這個決策樹就能得到最終的分類結果。
該算法可以分為兩大部分:
1. 構建決策樹部分
2. 使用決策樹分類部分
其中,第一部分是重點難點。
決策樹構造偽代碼
1 # ============================================== 2 # 輸入: 3 # 數據集 4 # 輸出: 5 # 構造好的決策樹(也即訓練集) 6 # ============================================== 7 def 創建決策樹: 8 '創建決策樹' 9 10 if (數據集中所有樣本分類一致): 11 創建攜帶類標簽的葉子節點 12 else: 13 尋找划分數據集的最好特征 14 根據最好特征划分數據集 15 for 每個划分的數據集: 16 創建決策子樹(遞歸方式)
核心問題一:依據什么划分數據集
可采用ID3算法思路:如果以某種特種特征來划分數據集,會導致數據集發生最大程度的改變,那么就使用這種特征值來划分。
那么又該如何衡量數據集的變化程度呢?
可采用熵來進行衡量。這個字讀作shang,第一聲,不要讀成di啊,哈哈!
它用來衡量信息集的無序程度,其計算公式如下:

其中:
1. x是指分類。要注意決策樹的分類是離散的。
2. P(x)是指任一樣本為該分類的概率
顯然,與原數據集相比,熵差最大的划分集就是最優划分集。
對數據集求熵的代碼如下:
1 # ============================================== 2 # 輸入: 3 # dataSet: 數據集文件名(含路徑) 4 # 輸出: 5 # shannonEnt: 輸入數據集的香農熵 6 # ============================================== 7 def calcShannonEnt(dataSet): 8 '計算香農熵' 9 10 # 數據集個數 11 numEntries = len(dataSet) 12 # 標簽集合 13 labelCounts = {} 14 for featVec in dataSet: # 行遍歷數據集 15 # 當前標簽 16 currentLabel = featVec[-1] 17 # 加入標簽集合 18 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 19 labelCounts[currentLabel] += 1 20 21 # 計算當前數據集的香農熵並返回 22 shannonEnt = 0.0 23 for key in labelCounts: 24 prob = float(labelCounts[key])/numEntries 25 shannonEnt -= prob * log(prob,2) 26 27 return shannonEnt
可用如下函數創建測試數據集並對其求熵:
1 # ============================================== 2 # 輸入: 3 # 空 4 # 輸出: 5 # dataSet: 測試數據集列表 6 # ============================================== 7 def createDataSet(): 8 '創建測試數據集' 9 10 dataSet = [[1, 1, 'yes'], 11 [1, 1, 'yes'], 12 [1, 0, 'no'], 13 [0, 1, 'no'], 14 [0, 1, 'no']] 15 16 return dataSet 17 18 def test(): 19 '測試' 20 21 # 創建測試數據集 22 myDat = createDataSet() 23 # 求出其熵並打印 24 print calcShannonEnt(myDat)
運行結果如下:

如果我們修改測試數據集的某些數據,讓其看起來顯得混亂點,則得到的熵的值會更大。
還有其他描述集合無序程度的方法,比如說基尼不純度等等,這里就不再討論了。
核心問題二:如何划分數據集
這涉及到一些細節上面的問題了,比如:每次划分是否需要剔除某些字段?如何對各種划分所得的熵差進行比較並進行最優划分等等。
首先是具體划分函數:
1 # ============================================== 2 # 輸入: 3 # dataSet: 訓練集文件名(含路徑) 4 # axis: 用於划分的特征的列數 5 # value: 划分值 6 # 輸出: 7 # retDataSet: 划分后的子列表 8 # ============================================== 9 def splitDataSet(dataSet, axis, value): 10 '數據集划分' 11 12 # 划分結果 13 retDataSet = [] 14 for featVec in dataSet: # 逐行遍歷數據集 15 if featVec[axis] == value: # 如果目標特征值等於value 16 # 抽取掉數據集中的目標特征值列 17 reducedFeatVec = featVec[:axis] 18 reducedFeatVec.extend(featVec[axis+1:]) 19 # 將抽取后的數據加入到划分結果列表中 20 retDataSet.append(reducedFeatVec) 21 22 return retDataSet
然后是選擇最優划分函數:
1 # =============================================== 2 # 輸入: 3 # dataSet: 數據集 4 # 輸出: 5 # bestFeature: 和原數據集熵差最大划分對應的特征的列號 6 # =============================================== 7 def chooseBestFeatureToSplit(dataSet): 8 '選擇最佳划分方案' 9 10 # 特征個數 11 numFeatures = len(dataSet[0]) - 1 12 # 原數據集香農熵 13 baseEntropy = calcShannonEnt(dataSet) 14 # 暫存最大熵增量 15 bestInfoGain = 0.0; 16 # 和原數據集熵差最大的划分對應的特征的列號 17 bestFeature = -1 18 19 for i in range(numFeatures): # 逐列遍歷數據集 20 # 獲取該列所有特征值 21 featList = [example[i] for example in dataSet] 22 # 將特征值列featList的值唯一化並保存到集合uniqueVals 23 uniqueVals = set(featList) 24 25 # 新划分法香農熵 26 newEntropy = 0.0 27 # 計算該特征划分下所有划分子集的香農熵,並疊加。 28 for value in uniqueVals: # 遍歷該特征列所有特征值 29 subDataSet = splitDataSet(dataSet, i, value) 30 prob = len(subDataSet)/float(len(dataSet)) 31 newEntropy += prob * calcShannonEnt(subDataSet) 32 33 # 保存所有划分法中,和原數據集熵差最大划分對應的特征的列號。 34 infoGain = baseEntropy - newEntropy 35 if (infoGain > bestInfoGain): 36 bestInfoGain = infoGain 37 bestFeature = i 38 39 return bestFeature
得到的結果是0:

而上面的代碼也看到,測試數據集為:
1 dataSet = [[1, 1, 'yes'], 2 [1, 1, 'yes'], 3 [1, 0, 'no'], 4 [0, 1, 'no'], 5 [0, 1, 'no']]
顯然,按照第0列特征划分會更加合理,區分度更大。
核心問題三:如何具體實現樹結構
通過對前面兩個問題的分析,划分數據集這一塊已經清楚明了了。
那么如何用這些多層次的划分子集搭建出一個樹結構呢?這部分更多涉及到編程技巧,某種程度上來說,就是用Python實現樹的問題。
在Python中,可以用字典來具體實現樹:字典的鍵存放節點信息,值存放分支及子樹/葉子節點信息。
比如說對於下面這個樹,用Python的字典表述就是:{'no surfacing' : {0, 'no', 1 : {'flippers' : {0 : 'no', 1 : 'yes'}}}}
如下構建樹部分代碼。該函數調用后將形成決策樹:
1 # =============================================== 2 # 輸入: 3 # classList: 類標簽集 4 # 輸出: 5 # sortedClassCount[0][0]: 出現次數最多的標簽 6 # =============================================== 7 def majorityCnt(classList): 8 '采用多數表決的方式求出classList中出現次數最多的類標簽' 9 10 classCount={} 11 for vote in classList: 12 if vote not in classCount.keys(): classCount[vote] = 0 13 classCount[vote] += 1 14 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 15 16 return sortedClassCount[0][0] 17 18 # =============================================== 19 # 輸入: 20 # dataSet: 數據集 21 # labels: 划分標簽集 22 # 輸出: 23 # myTree: 生成的決策樹 24 # =============================================== 25 def createTree(dataSet,labels): 26 '創建決策樹' 27 28 # 獲得類標簽列表 29 classList = [example[-1] for example in dataSet] 30 31 # 遞歸終止條件一:如果數據集內所有分類一致 32 if classList.count(classList[0]) == len(classList): 33 return classList[0] 34 35 # 遞歸終止條件二:如果所有特征都划分完畢 36 if len(dataSet[0]) == 1: 37 # 將它們都歸為一類並返回 38 return majorityCnt(classList) 39 40 # 選擇最佳划分特征 41 bestFeat = chooseBestFeatureToSplit(dataSet) 42 # 最佳划分對應的划分標簽。注意不是分類標簽 43 bestFeatLabel = labels[bestFeat] 44 # 構建字典空樹 45 myTree = {bestFeatLabel:{}} 46 # 從划分標簽列表中刪掉划分后的元素 47 del(labels[bestFeat]) 48 # 獲取最佳划分對應特征的所有特征值 49 featValues = [example[bestFeat] for example in dataSet] 50 # 對特征值列表featValues唯一化,結果存於uniqueVals。 51 uniqueVals = set(featValues) 52 53 for value in uniqueVals: # 逐行遍歷特征值集合 54 # 保存所有划分標簽信息並將其伙同划分后的數據集傳遞進下一次遞歸 55 subLabels = labels[:] 56 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels) 57 58 return myTree
如下代碼可用於測試函數是否正常執行:
1 # ============================================== 2 # 輸入: 3 # 空 4 # 輸出: 5 # 用於測試的數據集和划分標簽集 6 # ============================================== 7 def createDataSet(): 8 '創建測試數據集' 9 10 dataSet = [[1, 1, 'yes'], 11 [1, 1, 'yes'], 12 [1, 0, 'no'], 13 [0, 1, 'no'], 14 [0, 1, 'no']] 15 labels = ['no surfacing', 'flippers'] 16 17 return dataSet, labels 18 19 def test(): 20 '測試' 21 22 myDat, labels = createDataSet() 23 myTree = createTree(myDat, labels) 24 print myTree
運行結果:

使用Matplotlib繪制樹形圖
當決策樹構建好了以后,自然需要用一種方式來顯示給開發人員。僅僅是一個字典表達式很難讓人滿意。
因此,可采用Matplotlib來繪制樹形圖。
這涉及到兩方面的知識:
1. 遍歷樹獲取樹的高度,葉子數等信息。
2. Matplotlib繪制圖像的一些API
對於第一部分的任務,可以用遞歸的方式遍歷字典樹,從而獲得樹的相關信息。
下面給出求樹的葉子樹及樹高的函數:
1 # =============================================== 2 # 輸入: 3 # myTree: 決策樹 4 # 輸出: 5 # numLeafs: 決策樹的葉子數 6 # =============================================== 7 def getNumLeafs(myTree): 8 '計算決策樹的葉子數' 9 10 # 葉子數 11 numLeafs = 0 12 # 節點信息 13 firstStr = myTree.keys()[0] 14 # 分支信息 15 secondDict = myTree[firstStr] 16 17 for key in secondDict.keys(): # 遍歷所有分支 18 # 子樹分支則遞歸計算 19 if type(secondDict[key]).__name__=='dict': 20 numLeafs += getNumLeafs(secondDict[key]) 21 # 葉子分支則葉子數+1 22 else: numLeafs +=1 23 24 return numLeafs 25 26 # =============================================== 27 # 輸入: 28 # myTree: 決策樹 29 # 輸出: 30 # maxDepth: 決策樹的深度 31 # =============================================== 32 def getTreeDepth(myTree): 33 '計算決策樹的深度' 34 35 # 最大深度 36 maxDepth = 0 37 # 節點信息 38 firstStr = myTree.keys()[0] 39 # 分支信息 40 secondDict = myTree[firstStr] 41 42 for key in secondDict.keys(): # 遍歷所有分支 43 # 子樹分支則遞歸計算 44 if type(secondDict[key]).__name__=='dict': 45 thisDepth = 1 + getTreeDepth(secondDict[key]) 46 # 葉子分支則葉子數+1 47 else: thisDepth = 1 48 49 # 更新最大深度 50 if thisDepth > maxDepth: maxDepth = thisDepth 51 52 return maxDepth
對於第二部分的任務 - 畫樹,其實本質就是畫點和畫線,下面給出基本的線畫法:
1 import matplotlib.pyplot as plt 2 3 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 4 leafNode = dict(boxstyle="round4", fc="0.8") 5 arrow_args = dict(arrowstyle="<-") 6 7 # ================================================== 8 # 輸入: 9 # nodeTxt: 終端節點顯示內容 10 # centerPt: 終端節點坐標 11 # parentPt: 起始節點坐標 12 # nodeType: 終端節點樣式 13 # 輸出: 14 # 在圖形界面中顯示輸入參數指定樣式的線段(終端帶節點) 15 # ================================================== 16 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 17 '畫線(末端帶一個點)' 18 19 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 20 21 def createPlot(): 22 '繪制有向線段(末端帶一個節點)並顯示' 23 24 # 新建一個圖對象並清空 25 fig = plt.figure(1, facecolor='white') 26 fig.clf() 27 # 設置1行1列個圖區域,並選擇其中的第1個區域展示數據。 28 createPlot.ax1 = plt.subplot(111, frameon=False) 29 30 # 畫線(末端帶一個節點) 31 plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode) 32 plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode) 33 34 # 顯示繪制結果 35 plt.show()
調用 createPlot 函數即可顯示繪制結果:

下面,將這兩部分內容整合起來,寫出最終繪制樹的代碼:
1 import matplotlib.pyplot as plt 2 3 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 4 leafNode = dict(boxstyle="round4", fc="0.8") 5 arrow_args = dict(arrowstyle="<-") 6 7 # ================================================== 8 # 輸入: 9 # nodeTxt: 終端節點顯示內容 10 # centerPt: 終端節點坐標 11 # parentPt: 起始節點坐標 12 # nodeType: 終端節點樣式 13 # 輸出: 14 # 在圖形界面中顯示輸入參數指定樣式的線段(終端帶節點) 15 # ================================================== 16 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 17 '畫線(末端帶一個點)' 18 19 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 20 21 # ================================================================= 22 # 輸入: 23 # cntrPt: 終端節點坐標 24 # parentPt: 起始節點坐標 25 # txtString: 待顯示文本內容 26 # 輸出: 27 # 在圖形界面指定位置(cntrPt和parentPt中間)顯示文本內容(txtString) 28 # ================================================================= 29 def plotMidText(cntrPt, parentPt, txtString): 30 '在指定位置添加文本' 31 32 # 中間位置坐標 33 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] 34 yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 35 36 createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) 37 38 # =================================== 39 # 輸入: 40 # myTree: 決策樹 41 # parentPt: 根節點坐標 42 # nodeTxt: 根節點坐標信息 43 # 輸出: 44 # 在圖形界面繪制決策樹 45 # =================================== 46 def plotTree(myTree, parentPt, nodeTxt): 47 '繪制決策樹' 48 49 # 當前樹的葉子數 50 numLeafs = getNumLeafs(myTree) 51 # 當前樹的節點信息 52 firstStr = myTree.keys()[0] 53 # 定位第一棵子樹的位置(這是蛋疼的一部分) 54 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) 55 56 # 繪制當前節點到子樹節點(含子樹節點)的信息 57 plotMidText(cntrPt, parentPt, nodeTxt) 58 plotNode(firstStr, cntrPt, parentPt, decisionNode) 59 60 # 獲取子樹信息 61 secondDict = myTree[firstStr] 62 # 開始繪制子樹,縱坐標-1。 63 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD 64 65 for key in secondDict.keys(): # 遍歷所有分支 66 # 子樹分支則遞歸 67 if type(secondDict[key]).__name__=='dict': 68 plotTree(secondDict[key],cntrPt,str(key)) 69 # 葉子分支則直接繪制 70 else: 71 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW 72 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 73 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 74 75 # 子樹繪制完畢,縱坐標+1。 76 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD 77 78 # ============================== 79 # 輸入: 80 # myTree: 決策樹 81 # 輸出: 82 # 在圖形界面顯示決策樹 83 # ============================== 84 def createPlot(inTree): 85 '顯示決策樹' 86 87 # 創建新的圖像並清空 - 無橫縱坐標 88 fig = plt.figure(1, facecolor='white') 89 fig.clf() 90 axprops = dict(xticks=[], yticks=[]) 91 createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 92 93 # 樹的總寬度 高度 94 plotTree.totalW = float(getNumLeafs(inTree)) 95 plotTree.totalD = float(getTreeDepth(inTree)) 96 97 # 當前繪制節點的坐標 98 plotTree.xOff = -0.5/plotTree.totalW; 99 plotTree.yOff = 1.0; 100 101 # 繪制決策樹 102 plotTree(inTree, (0.5,1.0), '') 103 104 plt.show() 105 106 def test(): 107 '測試' 108 109 myDat, labels = createDataSet() 110 myTree = createTree(myDat, labels) 111 createPlot(myTree)
運行結果如下圖:

關於決策樹的存儲
這部分也很重要。
生成一個決策樹比較耗時間,誰也不想每次啟動程序都重新進行機器學習吧。那么能否將學習結果 - 決策樹保存到硬盤中去呢?
答案是肯定的,以下兩個函數分別實現了決策樹的存儲與打開:
1 # ====================== 2 # 輸入: 3 # myTree: 決策樹 4 # 輸出: 5 # 決策樹文件 6 # ====================== 7 def storeTree(inputTree,filename): 8 '保存決策樹' 9 10 import pickle 11 fw = open(filename,'w') 12 pickle.dump(inputTree,fw) 13 fw.close() 14 15 # ======================== 16 # 輸入: 17 # filename: 決策樹文件名 18 # 輸出: 19 # pickle.load(fr): 決策樹 20 # ======================== 21 def grabTree(filename): 22 '打開決策樹' 23 24 import pickle 25 fr = open(filename) 26 return pickle.load(fr)
使用決策樹進行分類
終於到了這一步,也是最終一步了。
拿到需要分類的數據后,遍歷決策樹直至葉子節點,即可得到分類結果,是不是很簡單呢?
下面給出遍歷及測試代碼:
1 # ======================== 2 # 輸入: 3 # inputTree: 決策樹文件名 4 # featLabels: 分類標簽集 5 # testVec: 待分類對象 6 # 輸出: 7 # classLabel: 分類結果 8 # ======================== 9 def classify(inputTree,featLabels,testVec): 10 '使用決策樹分類' 11 12 # 當前分類標簽 13 firstStr = inputTree.keys()[0] 14 secondDict = inputTree[firstStr] 15 # 找到當前分類標簽在分類標簽集中的下標 16 featIndex = featLabels.index(firstStr) 17 # 獲取待分類對象中當前分類的特征值 18 key = testVec[featIndex] 19 20 # 遍歷 21 valueOfFeat = secondDict[key] 22 23 # 子樹分支則遞歸 24 if isinstance(valueOfFeat, dict): 25 classLabel = classify(valueOfFeat, featLabels, testVec) 26 # 葉子分支則返回結果 27 else: classLabel = valueOfFeat 28 29 return classLabel 30 31 def test(): 32 '測試' 33 34 myDat, labels = createDataSet() 35 myTree = createTree(myDat, labels) 36 # 再創建一次數據的原因是創建決策樹函數會將labels值改動 37 myDat, labels = createDataSet() 38 print classify(myTree, labels, [1,1])
運行結果如下:

OK,一個完整的決策樹使用例子就實現了。
小結
1. 本文演示的是最經典ID3決策樹,但它在實際應用中存在過度匹配的問題。在以后的文章中會學習如何對決策樹進行裁剪。
2. 本文采用的ID3決策樹算法只能用於標稱型數據。對於數值型數據,需要使用Cart決策樹構造算法。這個算法將在以后進行深入學習。
