Drawbacks of ID3 and C4.5
An ID3 decision tree can have more than two branches at a node, but it cannot handle continuous-valued features.
At every step, ID3 picks the feature with the largest information gain and splits the data once per value of that feature, so a feature with 4 distinct values splits the data into 4 parts. Once a feature has been used for a split, it plays no further role in the rest of the algorithm, which is why this splitting strategy is sometimes criticized as too aggressive.
C4.5 instead uses the information gain ratio as its splitting criterion. Like ID3, C4.5 classifiers still tend to overfit. To address overfitting, this post introduces another algorithm, CART.
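For reference, the gain ratio mentioned above is the standard C4.5 criterion (it is not used by the code in this post):

$$ GainRatio(D, A) = \frac{Gain(D, A)}{H_A(D)}, \qquad H_A(D) = -\sum_{i=1}^{n} \frac{|D_i|}{|D|} \log_2 \frac{|D_i|}{|D|} $$

where $Gain(D, A)$ is ID3's information gain and $H_A(D)$ is the split information of feature $A$, with $D_1, \dots, D_n$ the subsets of $D$ induced by the $n$ values of $A$.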
CART (Classification and Regression Tree)
CART consists of feature selection, tree generation, and pruning, and it can be used for both classification and regression.
Classification: e.g. sunny/overcast/rainy weather, a user's gender, whether an email is spam;
Regression: predicting a real value, e.g. tomorrow's temperature or a user's age.
Generating a CART decision tree is a process of recursively building a binary tree. For classification (and for the pruning step later on), splits are chosen by maximizing the impurity decrease, where impurity is measured here with the Gini index; ID3's entropy formula could be used instead (shown below for comparison).
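For comparison, the entropy ID3 would use is the standard

$$ H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k $$

where $p_k$ is the proportion of class $k$ in $D$; the Gini index defined next plays the same role as an impurity measure but avoids the logarithm.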
Gini index
In a classification problem with $K$ classes, let $p_k$ be the probability that a sample point belongs to class $k$. The Gini index of the probability distribution is defined as

$$ Gini(p) = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2 $$

For a given sample set $D$, with $C_k$ the subset of samples belonging to class $k$, the Gini index is

$$ Gini(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|} \right)^2 $$
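As a quick sanity check of the formula, here is a tiny self-contained sketch (separate from the code later in this post; the toy labels are made up for illustration) that computes Gini(D) for a list of class labels:

from collections import Counter

def gini_index(labels):
    # Gini(D) = 1 - sum_k (|C_k| / |D|)^2
    n = len(labels)
    counts = Counter(labels)
    return 1 - sum((c / n) ** 2 for c in counts.values())

# three samples of one class and one of another: 1 - (0.75**2 + 0.25**2) = 0.375
print(gini_index(['setosa', 'setosa', 'setosa', 'versicolor']))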
The generated tree is a binary tree: each internal node tests one feature against one split value, and each leaf stores a class distribution.
Pruning algorithm
The CART pruning algorithm cuts some subtrees off the bottom of the fully grown tree, making the tree smaller (and the model simpler) so that it predicts unseen data more accurately and overfitting is reduced.
Post-pruning first grows a complete tree on the training set and then examines the non-leaf nodes from the bottom up, comparing the impurity decrease at each node with a given threshold to decide whether the subtree rooted there should be replaced by a leaf node.
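Concretely, for a node whose two children are leaves holding sample sets $D_L$ and $D_R$ (with $D = D_L \cup D_R$), the prune() function below computes the impurity decrease

$$ \Delta = Gini(D) - \frac{|D_L|}{|D|}\,Gini(D_L) - \frac{|D_R|}{|D|}\,Gini(D_R) $$

and, if $\Delta$ is smaller than the given threshold (the miniGain argument), replaces the parent with the leaf that holds more samples.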
Code
I have commented most of the functions in some detail, which I hope helps to make the algorithm easier to follow.
Since I cannot attach files here, I use a crude workaround: copy the raw data below into a local text file, rename it to dataSet.csv, and put it in the same directory as the code files.

SepalLength,SepalWidth,PetalLength,PetalWidth,Name
5.1,3.5,1.4,0.2,setosa
4.9,3,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3,1.4,0.1,setosa
4.3,3,1.1,0.1,setosa
5.8,4,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5,3,1.6,0.2,setosa
5,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.1,1.5,0.1,setosa
4.4,3,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2,setosa
5,3.3,1.4,0.2,setosa
7,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5,2,3.5,1,versicolor
5.9,3,4.2,1.5,versicolor
6,2.2,4,1,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3,4.5,1.5,versicolor
5.8,2.7,4.1,1,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3,5,1.7,versicolor
6,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1,versicolor
5.8,2.7,3.9,1.2,versicolor
6,2.7,5.1,1.6,versicolor
5.4,3,4.5,1.5,versicolor
6,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3,4.1,1.3,versicolor
5.5,2.5,4,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3,4.6,1.4,versicolor
5.8,2.6,4,1.2,versicolor
5,2.3,3.3,1,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3,5.8,2.2,virginica
7.6,3,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3,5.5,2.1,virginica
5.7,2.5,5,2,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6,2.2,5,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2,virginica
7.7,2.8,6.7,2,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6,3,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3,5.2,2.3,virginica
6.3,2.5,5,1.9,virginica
6.5,3,5.2,2,virginica
6.2,3.4,5.4,2.3,virginica
5.9,3,5.1,1.8,virginica

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 14 17:36:57 2018

@author: weixw
"""
import numpy as np


# Tree structure: a binary tree. Left branch: condition is true; right branch: condition is false.
# leftBranch:  left child node
# rightBranch: right child node
# col:         column index of the feature chosen at the best split
# value:       split value of that column
# results:     class counts (only set on leaf nodes)
# summary:     node statistics (impurity, sample count)
# data:        the rows held by this node
class Tree:
    def __init__(self, leftBranch=None, rightBranch=None, col=-1, value=None, results=None, summary=None, data=None):
        self.leftBranch = leftBranch
        self.rightBranch = rightBranch
        self.col = col
        self.value = value
        self.results = results
        self.summary = summary
        self.data = data

    def __str__(self):
        print(u"column index: %d" % self.col)
        print(u"split value: %s" % self.value)
        print(u"node summary: %s" % self.summary)
        return ""


# Split the data set on one column.
def splitDataSet(dataSet, value, column):
    leftList = []
    rightList = []
    # numeric split value: compare with >=
    if isinstance(value, int) or isinstance(value, float):
        for rowData in dataSet:
            # rows whose value in the given column is >= value go to leftList, the rest to rightList
            if rowData[column] >= value:
                leftList.append(rowData)
            else:
                rightList.append(rowData)
    # nominal split value: compare with ==
    else:
        for rowData in dataSet:
            # rows whose value in the given column equals value go to leftList, the rest to rightList
            if rowData[column] == value:
                leftList.append(rowData)
            else:
                rightList.append(rowData)
    return leftList, rightList


# Count how many samples each class label has.
'''
Helper used by the gini calculation. For example, if the labels in dataSet are
['A', 'B', 'C', 'A', 'A', 'D'], the output is {'A': 3, 'B': 1, 'C': 1, 'D': 1}.
'''
def calculateDiffCount(dataSet):
    results = {}
    for data in dataSet:
        # data[-1] is the last column, i.e. the class label
        if data[-1] not in results:
            results.setdefault(data[-1], 1)
        else:
            results[data[-1]] += 1
    return results


# Gini index
def gini(dataSet):
    # total number of rows
    length = len(dataSet)
    # class counts of the data set
    results = calculateDiffCount(dataSet)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp


# Build the decision tree
'''Algorithm steps'''
'''Starting from the root node, recursively apply the following steps to build a binary decision tree:
1. Let D be the training data at the current node. For every feature A and every possible split value a,
   split D into D1 and D2 according to the test A >= a ("yes"/"no") and compute the impurity decrease
   using the Gini index.
2. Among all features A and all candidate split points a, choose the feature and split point with the
   largest impurity decrease as the best feature and best split point, create two child nodes from the
   current node, and distribute the training data into the two children accordingly.
3. Recursively apply steps 1 and 2 to both children until the stopping condition is met.
4. The result is the CART decision tree.
'''
# evaluationFunc=gini: the Gini index is used as the impurity measure
def buildDecisionTree(dataSet, evaluationFunc=gini):
    # Gini index of the current data set
    baseGain = evaluationFunc(dataSet)
    # length of one row (i.e. the number of columns)
    columnLength = len(dataSet[0])
    # number of rows
    rowLength = len(dataSet)
    # initialisation
    bestGain = 0.0    # largest impurity decrease found so far
    bestValue = None  # (column index, split value) of the best split
    bestSet = None    # (left subset, right subset) produced by the best split
    # iterate over every column except the label column (the last one)
    for col in range(columnLength - 1):
        # values of the chosen column
        colSet = [example[col] for example in dataSet]
        # unique values of the chosen column
        uniqueColSet = set(colSet)
        # try every unique value as a split point
        for value in uniqueColSet:
            # split the data set
            leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
            # proportion of the left subset (in Python 3, "/" is float division)
            prop = len(leftDataSet) / rowLength
            # impurity decrease
            infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
            # remember the column index, split value and subsets of the best split
            if infoGain > bestGain:
                bestGain = infoGain
                bestValue = (col, value)
                bestSet = (leftDataSet, rightDataSet)
    # node statistics
    # nodeDescription = {'impurity:%.3f'%baseGain, 'sample:%d'%rowLength}
    nodeDescription = {'impurity': '%.3f' % baseGain, 'sample': '%d' % rowLength}
    # if the labels are not all identical we can keep splitting;
    # the recursion needs a termination condition
    if bestGain > 0:
        # recursion: build the left and right subtrees
        leftBranch = buildDecisionTree(bestSet[0], evaluationFunc)
        rightBranch = buildDecisionTree(bestSet[1], evaluationFunc)
        return Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0],
                    value=bestValue[1], summary=nodeDescription, data=bestSet)
    else:
        # all labels are identical: stop splitting and create a leaf
        return Tree(results=calculateDiffCount(dataSet), summary=nodeDescription, data=dataSet)


def createTree(dataSet, evaluationFunc=gini):
    # Recursively build the tree as a nested dict; the recursion stops when the gain is 0.
    # Gini index of the current data set
    baseGain = evaluationFunc(dataSet)
    # length of one row (i.e. the number of columns)
    columnLength = len(dataSet[0])
    # number of rows
    rowLength = len(dataSet)
    # initialisation
    bestGain = 0.0    # largest impurity decrease found so far
    bestValue = None  # (column index, split value) of the best split
    bestSet = None    # (left subset, right subset) produced by the best split
    # iterate over every column except the label column (the last one)
    for col in range(columnLength - 1):
        # values of the chosen column
        colSet = [example[col] for example in dataSet]
        # unique values of the chosen column
        uniqueColSet = set(colSet)
        # try every unique value as a split point
        for value in uniqueColSet:
            # split the data set
            leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
            # proportion of the left subset (in Python 3, "/" is float division)
            prop = len(leftDataSet) / rowLength
            # impurity decrease
            infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
            # remember the column index, split value and subsets of the best split
            if infoGain > bestGain:
                bestGain = infoGain
                bestValue = (col, value)
                bestSet = (leftDataSet, rightDataSet)

    impurity = u'%.3f' % baseGain
    sample = '%d' % rowLength

    if bestGain > 0:
        bestFeatLabel = u'serial:%s\nimpurity:%s\nsample:%s' % (bestValue[0], impurity, sample)
        myTree = {bestFeatLabel: {}}
        myTree[bestFeatLabel][bestValue[1]] = createTree(bestSet[0], evaluationFunc)
        myTree[bestFeatLabel]['no'] = createTree(bestSet[1], evaluationFunc)
        return myTree
    else:
        # the recursion needs a return value
        bestFeatValue = u'%s\nimpurity:%s\nsample:%s' % (str(calculateDiffCount(dataSet)), impurity, sample)
        return bestFeatValue


# Classification test:
'''Walk the binary tree with the given test sample until a matching leaf node is found.'''
'''For example, with the test sample [5.9, 3, 4.2, 1.75] and the tree built from the training data,
the root node splits on column 2; the sample's value 4.2 is compared with the root's split value 3.
Since 4.2 >= 3 we follow the left subtree, otherwise we would follow the right subtree; the leaf we
end up in is the result.'''
def classify(data, tree):
    # if this is a leaf node, return its information; otherwise keep walking
    if tree.results != None:
        return u"%s\n%s" % (tree.results, tree.summary)
    else:
        branch = None
        v = data[tree.col]
        # numeric value
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.leftBranch
            else:
                branch = tree.rightBranch
        else:  # nominal value
            if v == tree.value:
                branch = tree.leftBranch
            else:
                branch = tree.rightBranch
        return classify(data, branch)


def loadCSV(fileName):
    def convertTypes(s):
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s
    data = np.loadtxt(fileName, dtype='str', delimiter=',')
    data = data[1:, :]
    dataSet = [[convertTypes(item) for item in row] for row in data]
    return dataSet


# Majority vote:
# the value that occurs most often in the list wins
def majorityCnt(classList):
    import operator
    classCounts = {}
    for value in classList:
        if value not in classCounts.keys():
            classCounts[value] = 0
        classCounts[value] += 1
    sortedClassCount = sorted(classCounts.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# Pruning (pre-order traversal: root => left subtree => right subtree)
'''Algorithm steps
1. Starting at the root of the binary tree, call the pruning function recursively until both children
   of a node are leaf nodes.
2. Compute the impurity decrease infoGain of that parent node (whose children are leaves).
3. If infoGain < miniGain, replace the parent node with the leaf that holds more samples.
4. Repeat 1, 2, 3 until the whole tree has been traversed.
'''
def prune(tree, miniGain, evaluationFunc=gini):
    print(u"current node:")
    print(str(tree))
    # if the left child is not a leaf, recurse into the left subtree
    if tree.leftBranch.results == None:
        print(u"left child:")
        print(str(tree.leftBranch))
        prune(tree.leftBranch, miniGain, evaluationFunc)
    # if the right child is not a leaf, recurse into the right subtree
    if tree.rightBranch.results == None:
        print(u"right child:")
        print(str(tree.rightBranch))
        prune(tree.rightBranch, miniGain, evaluationFunc)
    # both children are leaf nodes
    if tree.leftBranch.results != None and tree.rightBranch.results != None:
        # number of samples in the left leaf
        leftLen = len(tree.leftBranch.data)
        # number of samples in the right leaf
        rightLen = len(tree.rightBranch.data)
        # proportion of the left leaf
        leftProp = leftLen / (leftLen + rightLen)
        # impurity decrease of this node (whose children are leaves)
        infoGain = (evaluationFunc(tree.leftBranch.data + tree.rightBranch.data) -
                    leftProp * evaluationFunc(tree.leftBranch.data) -
                    (1 - leftProp) * evaluationFunc(tree.rightBranch.data))
        # if the impurity decrease is below the given threshold, the leaves differ little
        # from their parent and the node can be pruned
        if infoGain < miniGain:
            # merge the data of the two leaves
            dataSet = tree.leftBranch.data + tree.rightBranch.data
            # label column
            classLabels = [example[-1] for example in dataSet]
            # most frequent label
            keyLabel = majorityCnt(classLabels)
            # check which leaf the majority label belongs to
            if keyLabel in tree.leftBranch.results:
                # the left leaf replaces the parent node
                tree.data = tree.leftBranch.data
                tree.results = tree.leftBranch.results
                tree.summary = tree.leftBranch.summary
            else:
                # the right leaf replaces the parent node
                tree.data = tree.rightBranch.data
                tree.results = tree.rightBranch.results
                tree.summary = tree.rightBranch.summary
            tree.leftBranch = None
            tree.rightBranch = None

'''
Created on Oct 14, 2010

@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="circle", fc="0.7")
arrow_args = dict(arrowstyle="<-")

# Count the leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    # dict keys to list (Python 3)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # a child of type dict is a subtree, anything else (a str) is a leaf node
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Compute the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    # dict keys to list (Python 3)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is a subtree
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# Plot the decision tree, variant 1
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    # width / height spacing
    plotTree.totalW = float(getNumLeafs(inTree)) - 3
    plotTree.totalD = float(getTreeDepth(inTree)) - 2
    # plotTree.totalW = float(getNumLeafs(inTree))
    # plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.95, 1.0), '')
    plt.show()

# Plot the decision tree, variant 2
def createPlot1(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    # width / height spacing
    plotTree.totalW = float(getNumLeafs(inTree)) - 4.5
    plotTree.totalD = float(getTreeDepth(inTree)) - 3
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (1.0, 1.0), '')
    plt.show()

# Plot a demo decision node and leaf node (decision node: sawtooth box, leaf node: circle)
# def createPlot():
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()
#     createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

# thisTree = retrieveTree(0)
# createPlot(thisTree)
# createPlot()
# myTree = retrieveTree(0)
# numLeafs = getNumLeafs(myTree)
# treeDepth = getTreeDepth(myTree)
# print(u"number of leaf nodes: %d" % numLeafs)
# print(u"tree depth: %d" % treeDepth)

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 15 14:16:59 2018

@author: weixw
"""
import myCart as mc

if __name__ == '__main__':
    import treePlotter as tp
    dataSet = mc.loadCSV("dataSet.csv")
    # note: gini is defined in myCart, so it has to be referenced as mc.gini here
    myTree = mc.createTree(dataSet, evaluationFunc=mc.gini)
    print(u"myTree:%s" % myTree)
    # plot the decision tree
    print(u"plotting the decision tree:")
    tp.createPlot1(myTree)
    decisionTree = mc.buildDecisionTree(dataSet, evaluationFunc=mc.gini)
    testData = [5.9, 3, 4.2, 1.75]
    r = mc.classify(testData, decisionTree)
    print(u"classification result before pruning:")
    print(r)
    print()
    mc.prune(decisionTree, 0.4)
    r1 = mc.classify(testData, decisionTree)
    print(u"classification result after pruning:")
    print(r1)
Results
The reason I wrote the extra createTree(dataSet, evaluationFunc=gini) function is that the plotting function createPlot1(myTree) expects the tree as a nested dict (a JSON-like structure).
Turning the generated decision tree into a picture makes it much easier to read.
You can of course also print the information stored in the custom Tree objects; I have already added print statements for that in the code.
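As a rough illustration of that structure, here is a hand-written dict with the same shape as createTree's output; the numbers are my own back-of-the-envelope values for the iris data, not copied from an actual run:

# Illustrative only: each internal node is a dict whose single key is the node's label text.
# The child reached when the split condition holds is stored under the split value,
# the other child under the key 'no'; leaves are plain strings.
myTree_shape = {
    'serial:2\nimpurity:0.667\nsample:150': {   # root: split on column 2 (PetalLength)
        3: ...,                                 # subtree for rows with PetalLength >= 3 (omitted here)
        'no': "{'setosa': 50}\nimpurity:0.000\nsample:50",  # leaf string for the other branch
    }
}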
The printed output is shown below. Because of screen size I did not paste all of it; you can compare it with the plotted decision tree, which helps to cross-check the two and deepen understanding.
The classification test result before pruning is as follows:
The classification test result after pruning:
As you can see, {'versicolor': 47} has replaced its parent node serial:3 and become a new leaf node.