Python Machine Learning (20) Decision Tree Series 3: CART Principles and Code Implementation


 

Shortcomings of ID3 and C4.5

An ID3 decision tree can have multiple branches per node, but it cannot handle continuous-valued features.

At each step, ID3 selects the current best feature by maximum information gain and splits the data on all values of that feature: if a feature has 4 values, the data is split into 4 parts. Once the data has been split on a feature, that feature plays no further role in the rest of the algorithm, which is why some consider this splitting strategy too hasty (see the sketch below).
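For instance, here is a sketch of an ID3-style multiway split on a 4-valued nominal feature (illustrative only, and not part of the CART code later in this article):

```python
# ID3 creates one branch per value of the chosen feature, after which the
# feature is exhausted and never reconsidered on that path.
rows = [('red', 1), ('blue', 2), ('green', 3), ('red', 4), ('yellow', 5)]
branches = {}
for color, label in rows:
    branches.setdefault(color, []).append((color, label))
print(len(branches))  # 4 branches: red, blue, green, yellow
```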

C4.5 instead uses the gain ratio as its criterion for choosing branches. Like ID3, C4.5's classification results are prone to overfitting.

To address the overfitting problem, this article introduces another algorithm: CART.

CART (Classification and Regression Tree)

CART consists of feature selection, tree generation, and pruning, and can be used for both classification and regression.

Classification: discrete targets, e.g. sunny/overcast/rainy weather, a user's gender, or whether an email is spam;

Regression: predicting a real value, e.g. tomorrow's temperature or a user's age.

 

Generating a CART decision tree means recursively building a binary decision tree. Both splitting and pruning use a maximum-gain criterion, where the gain here is computed with the Gini index formula; ID3's information entropy formula could be used instead.

Gini Index

In a classification problem, suppose there are K classes and the probability that a sample point belongs to class k is $p_k$. The Gini index of the probability distribution is defined as

$$\mathrm{Gini}(p) = \sum_{k=1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K} p_k^2$$

 

For a given sample set D, the Gini index is

$$\mathrm{Gini}(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|} \right)^2$$

where $C_k$ is the subset of samples in D that belong to class k, and $|D|$ is the total number of samples.

 
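As a quick numerical check of the formula (a minimal sketch with made-up labels; the gini() function in the code below works on full data rows rather than bare labels):

```python
# Gini index of a toy label list with class proportions 0.6 and 0.4.
labels = ['A'] * 3 + ['B'] * 2
gini = 1 - sum((labels.count(c) / len(labels)) ** 2 for c in set(labels))
print(gini)  # 1 - (0.6**2 + 0.4**2) = 0.48
```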

The generated binary tree looks similar to the following:

[Figure: example of a generated binary CART tree; image not reproduced]

Pruning Algorithm

The CART pruning algorithm removes some subtrees from the bottom of the "fully grown" decision tree, making the tree smaller (and the model simpler), so that it predicts unseen data more accurately and guards against overfitting.

Post-pruning first grows a complete decision tree from the training set and then examines the non-leaf nodes bottom-up, comparing the information gain against a given threshold to decide whether to replace the subtree rooted at a node with a leaf node; the criterion is sketched just below.
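Here is a minimal illustration of that test, using hypothetical leaf data (the 0.4 threshold matches the miniGain value used in the test script at the end of this article):

```python
# If splitting a node into its two leaves barely reduces the Gini impurity,
# collapse the split back into a single leaf. Leaf contents are hypothetical.
def gini_labels(labels):
    n = len(labels)
    return 1 - sum((labels.count(c) / n) ** 2 for c in set(labels))

left = ['versicolor'] * 45 + ['virginica'] * 2   # hypothetical left leaf
right = ['versicolor'] * 2 + ['virginica'] * 1   # hypothetical right leaf
merged = left + right
p = len(left) / len(merged)
gain = gini_labels(merged) - p * gini_labels(left) - (1 - p) * gini_labels(right)
if gain < 0.4:  # gain below the threshold: prune, keep a single leaf
    print('prune this split (gain = %.4f)' % gain)
```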


 

Code Implementation

I have added fairly detailed comments to almost every function, which I hope will help you understand how the algorithm works.

Since there is no attachment upload feature here, I use a simple workaround: copy the raw data below into a local txt file, rename it to dataSet.csv, and place it in the directory that holds the code files.
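A quick way to confirm the file is in place before running anything (a trivial sketch; run it from the code directory):

```python
# Verify dataSet.csv sits next to the code and starts with the header row.
import os
assert os.path.exists('dataSet.csv'), 'put dataSet.csv next to the code files'
with open('dataSet.csv') as f:
    print(f.readline().strip())  # expected: SepalLength,SepalWidth,PetalLength,PetalWidth,Name
```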

```csv
SepalLength,SepalWidth,PetalLength,PetalWidth,Name
5.1,3.5,1.4,0.2,setosa
4.9,3,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
4.6,3.4,1.4,0.3,setosa
5,3.4,1.5,0.2,setosa
4.4,2.9,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5.4,3.7,1.5,0.2,setosa
4.8,3.4,1.6,0.2,setosa
4.8,3,1.4,0.1,setosa
4.3,3,1.1,0.1,setosa
5.8,4,1.2,0.2,setosa
5.7,4.4,1.5,0.4,setosa
5.4,3.9,1.3,0.4,setosa
5.1,3.5,1.4,0.3,setosa
5.7,3.8,1.7,0.3,setosa
5.1,3.8,1.5,0.3,setosa
5.4,3.4,1.7,0.2,setosa
5.1,3.7,1.5,0.4,setosa
4.6,3.6,1,0.2,setosa
5.1,3.3,1.7,0.5,setosa
4.8,3.4,1.9,0.2,setosa
5,3,1.6,0.2,setosa
5,3.4,1.6,0.4,setosa
5.2,3.5,1.5,0.2,setosa
5.2,3.4,1.4,0.2,setosa
4.7,3.2,1.6,0.2,setosa
4.8,3.1,1.6,0.2,setosa
5.4,3.4,1.5,0.4,setosa
5.2,4.1,1.5,0.1,setosa
5.5,4.2,1.4,0.2,setosa
4.9,3.1,1.5,0.1,setosa
5,3.2,1.2,0.2,setosa
5.5,3.5,1.3,0.2,setosa
4.9,3.1,1.5,0.1,setosa
4.4,3,1.3,0.2,setosa
5.1,3.4,1.5,0.2,setosa
5,3.5,1.3,0.3,setosa
4.5,2.3,1.3,0.3,setosa
4.4,3.2,1.3,0.2,setosa
5,3.5,1.6,0.6,setosa
5.1,3.8,1.9,0.4,setosa
4.8,3,1.4,0.3,setosa
5.1,3.8,1.6,0.2,setosa
4.6,3.2,1.4,0.2,setosa
5.3,3.7,1.5,0.2,setosa
5,3.3,1.4,0.2,setosa
7,3.2,4.7,1.4,versicolor
6.4,3.2,4.5,1.5,versicolor
6.9,3.1,4.9,1.5,versicolor
5.5,2.3,4,1.3,versicolor
6.5,2.8,4.6,1.5,versicolor
5.7,2.8,4.5,1.3,versicolor
6.3,3.3,4.7,1.6,versicolor
4.9,2.4,3.3,1,versicolor
6.6,2.9,4.6,1.3,versicolor
5.2,2.7,3.9,1.4,versicolor
5,2,3.5,1,versicolor
5.9,3,4.2,1.5,versicolor
6,2.2,4,1,versicolor
6.1,2.9,4.7,1.4,versicolor
5.6,2.9,3.6,1.3,versicolor
6.7,3.1,4.4,1.4,versicolor
5.6,3,4.5,1.5,versicolor
5.8,2.7,4.1,1,versicolor
6.2,2.2,4.5,1.5,versicolor
5.6,2.5,3.9,1.1,versicolor
5.9,3.2,4.8,1.8,versicolor
6.1,2.8,4,1.3,versicolor
6.3,2.5,4.9,1.5,versicolor
6.1,2.8,4.7,1.2,versicolor
6.4,2.9,4.3,1.3,versicolor
6.6,3,4.4,1.4,versicolor
6.8,2.8,4.8,1.4,versicolor
6.7,3,5,1.7,versicolor
6,2.9,4.5,1.5,versicolor
5.7,2.6,3.5,1,versicolor
5.5,2.4,3.8,1.1,versicolor
5.5,2.4,3.7,1,versicolor
5.8,2.7,3.9,1.2,versicolor
6,2.7,5.1,1.6,versicolor
5.4,3,4.5,1.5,versicolor
6,3.4,4.5,1.6,versicolor
6.7,3.1,4.7,1.5,versicolor
6.3,2.3,4.4,1.3,versicolor
5.6,3,4.1,1.3,versicolor
5.5,2.5,4,1.3,versicolor
5.5,2.6,4.4,1.2,versicolor
6.1,3,4.6,1.4,versicolor
5.8,2.6,4,1.2,versicolor
5,2.3,3.3,1,versicolor
5.6,2.7,4.2,1.3,versicolor
5.7,3,4.2,1.2,versicolor
5.7,2.9,4.2,1.3,versicolor
6.2,2.9,4.3,1.3,versicolor
5.1,2.5,3,1.1,versicolor
5.7,2.8,4.1,1.3,versicolor
6.3,3.3,6,2.5,virginica
5.8,2.7,5.1,1.9,virginica
7.1,3,5.9,2.1,virginica
6.3,2.9,5.6,1.8,virginica
6.5,3,5.8,2.2,virginica
7.6,3,6.6,2.1,virginica
4.9,2.5,4.5,1.7,virginica
7.3,2.9,6.3,1.8,virginica
6.7,2.5,5.8,1.8,virginica
7.2,3.6,6.1,2.5,virginica
6.5,3.2,5.1,2,virginica
6.4,2.7,5.3,1.9,virginica
6.8,3,5.5,2.1,virginica
5.7,2.5,5,2,virginica
5.8,2.8,5.1,2.4,virginica
6.4,3.2,5.3,2.3,virginica
6.5,3,5.5,1.8,virginica
7.7,3.8,6.7,2.2,virginica
7.7,2.6,6.9,2.3,virginica
6,2.2,5,1.5,virginica
6.9,3.2,5.7,2.3,virginica
5.6,2.8,4.9,2,virginica
7.7,2.8,6.7,2,virginica
6.3,2.7,4.9,1.8,virginica
6.7,3.3,5.7,2.1,virginica
7.2,3.2,6,1.8,virginica
6.2,2.8,4.8,1.8,virginica
6.1,3,4.9,1.8,virginica
6.4,2.8,5.6,2.1,virginica
7.2,3,5.8,1.6,virginica
7.4,2.8,6.1,1.9,virginica
7.9,3.8,6.4,2,virginica
6.4,2.8,5.6,2.2,virginica
6.3,2.8,5.1,1.5,virginica
6.1,2.6,5.6,1.4,virginica
7.7,3,6.1,2.3,virginica
6.3,3.4,5.6,2.4,virginica
6.4,3.1,5.5,1.8,virginica
6,3,4.8,1.8,virginica
6.9,3.1,5.4,2.1,virginica
6.7,3.1,5.6,2.4,virginica
6.9,3.1,5.1,2.3,virginica
5.8,2.7,5.1,1.9,virginica
6.8,3.2,5.9,2.3,virginica
6.7,3.3,5.7,2.5,virginica
6.7,3,5.2,2.3,virginica
6.3,2.5,5,1.9,virginica
6.5,3,5.2,2,virginica
6.2,3.4,5.4,2.3,virginica
5.9,3,5.1,1.8,virginica
```
```python
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 14 17:36:57 2018

@author: weixw
"""
import numpy as np

# Binary tree node. Left subtree: condition is true; right subtree: condition is false.
# leftBranch:  left child node
# rightBranch: right child node
# col:         column index of the split with the largest gain
# value:       value that splits the data on that column
# results:     class counts (set only on leaf nodes)
# summary:     node statistics (impurity and sample count)
# data:        the rows held by this node
class Tree:
    def __init__(self, leftBranch=None, rightBranch=None, col=-1, value=None,
                 results=None, summary=None, data=None):
        self.leftBranch = leftBranch
        self.rightBranch = rightBranch
        self.col = col
        self.value = value
        self.results = results
        self.summary = summary
        self.data = data

    def __str__(self):
        # Prints the node information as a side effect.
        print(u"column index: %d" % self.col)
        print(u"split value: %s" % self.value)
        print(u"node statistics: %s" % self.summary)
        return ""


# Split the data set.
def splitDataSet(dataSet, value, column):
    leftList = []
    rightList = []
    # Numeric value?
    if isinstance(value, int) or isinstance(value, float):
        for rowData in dataSet:
            # Rows whose value in the given column is >= value go to leftList,
            # the rest go to rightList.
            if rowData[column] >= value:
                leftList.append(rowData)
            else:
                rightList.append(rowData)
    # Nominal value.
    else:
        for rowData in dataSet:
            # Rows whose value in the given column equals value go to leftList,
            # the rest go to rightList.
            if rowData[column] == value:
                leftList.append(rowData)
            else:
                rightList.append(rowData)
    return leftList, rightList

# Count the samples of each class label.
'''
Helper for computing the Gini value. For an input whose labels are
['A', 'B', 'C', 'A', 'A', 'D'] it returns {'A': 3, 'B': 1, 'C': 1, 'D': 1},
i.e. the number of samples per class in dataSet.
'''
def calculateDiffCount(dataSet):
    results = {}
    for data in dataSet:
        # data[-1] is the last column of the data set, i.e. the class label.
        if data[-1] not in results:
            results.setdefault(data[-1], 1)
        else:
            results[data[-1]] += 1
    return results


# Gini index formula.
def gini(dataSet):
    # Total number of rows.
    length = len(dataSet)
    # Class counts of the label column.
    results = calculateDiffCount(dataSet)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp

# Build the decision tree.
'''Algorithm steps'''
'''Starting from the root node, recursively apply the following steps to each
node to build a binary decision tree from the training data:
1 Let D be the node's training data. For each feature A and each of its
  possible values a, split D into D1 and D2 according to whether the test
  A >= a is "yes" or "no", and compute the gain with the Gini index.
2 Over all features A and all their candidate split points a, choose the
  feature and split point with the largest gain as the optimal split; create
  two child nodes and distribute the training rows to them accordingly.
3 Recursively apply steps 1 and 2 to both children until the stopping
  condition is met.
4 The result is the CART decision tree.
'''
# evaluationFunc=gini: the Gini index is used to measure impurity.
def buildDecisionTree(dataSet, evaluationFunc=gini):
    # Gini index of the current data set.
    baseGain = evaluationFunc(dataSet)
    # Length of a row, i.e. the number of columns.
    columnLength = len(dataSet[0])
    # Number of rows.
    rowLength = len(dataSet)
    # Initialization.
    bestGain = 0.0    # largest gain found so far
    bestValue = None  # (column index, split value) at the largest gain
    bestSet = None    # (left subset, right subset) produced by the best split
    # Iterate over every column except the label column (the last one).
    for col in range(columnLength - 1):
        # Values of this column.
        colSet = [example[col] for example in dataSet]
        # Unique values of this column.
        uniqueColSet = set(colSet)
        # Try every unique value as a split point.
        for value in uniqueColSet:
            # Split the data set.
            leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
            # Proportion of the left subset (Python 3 "/" is float division).
            prop = len(leftDataSet) / rowLength
            # Gain of this split.
            infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
            # Keep the column index, value and subsets with the largest gain.
            if infoGain > bestGain:
                bestGain = infoGain
                bestValue = (col, value)
                bestSet = (leftDataSet, rightDataSet)
    # Node statistics.
    nodeDescription = {'impurity': '%.3f' % baseGain, 'sample': '%d' % rowLength}
    # If the rows carry more than one class label we can keep splitting;
    # the recursion needs a termination condition.
    if bestGain > 0:
        # Recurse to build the left and right subtrees.
        leftBranch = buildDecisionTree(bestSet[0], evaluationFunc)
        rightBranch = buildDecisionTree(bestSet[1], evaluationFunc)
        return Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0],
                    value=bestValue[1], summary=nodeDescription, data=bestSet)
    else:
        # All rows share the same class label: stop and create a leaf.
        return Tree(results=calculateDiffCount(dataSet), summary=nodeDescription, data=dataSet)

def createTree(dataSet, evaluationFunc=gini):
    # Same split search as buildDecisionTree, but the tree is returned as
    # nested dicts so treePlotter can draw it; recursion stops when gain == 0.
    baseGain = evaluationFunc(dataSet)
    columnLength = len(dataSet[0])
    rowLength = len(dataSet)
    bestGain = 0.0
    bestValue = None
    bestSet = None
    for col in range(columnLength - 1):
        colSet = [example[col] for example in dataSet]
        uniqueColSet = set(colSet)
        for value in uniqueColSet:
            leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
            prop = len(leftDataSet) / rowLength
            infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
            if infoGain > bestGain:
                bestGain = infoGain
                bestValue = (col, value)
                bestSet = (leftDataSet, rightDataSet)

    impurity = u'%.3f' % baseGain
    sample = '%d' % rowLength

    if bestGain > 0:
        bestFeatLabel = u'serial:%s\nimpurity:%s\nsample:%s' % (bestValue[0], impurity, sample)
        myTree = {bestFeatLabel: {}}
        myTree[bestFeatLabel][bestValue[1]] = createTree(bestSet[0], evaluationFunc)
        myTree[bestFeatLabel]['no'] = createTree(bestSet[1], evaluationFunc)
        return myTree
    else:  # the recursion must return a value: a leaf label string
        bestFeatValue = u'%s\nimpurity:%s\nsample:%s' % (str(calculateDiffCount(dataSet)), impurity, sample)
        return bestFeatValue

# Classification test:
'''Walk the binary tree with the given test row until a matching leaf is reached.'''
'''For example, for the test row [5.9, 3, 4.2, 1.75] classification follows the
   trained tree: column 2 of the test row is 4.2 => compare it with the root
   node's (column 2) split value 3; if >= 3 descend into the left subtree,
   otherwise into the right one. The leaf reached is the result.'''
def classify(data, tree):
    # Leaf node: return its information; otherwise keep walking.
    if tree.results is not None:
        return u"%s\n%s" % (tree.results, tree.summary)
    else:
        branch = None
        v = data[tree.col]
        # Numeric attribute.
        if isinstance(v, int) or isinstance(v, float):
            if v >= tree.value:
                branch = tree.leftBranch
            else:
                branch = tree.rightBranch
        else:  # nominal attribute
            if v == tree.value:
                branch = tree.leftBranch
            else:
                branch = tree.rightBranch
        return classify(data, branch)

def loadCSV(fileName):
    def convertTypes(s):
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s
    data = np.loadtxt(fileName, dtype='str', delimiter=',')
    data = data[1:, :]  # drop the header row
    dataSet = [[convertTypes(item) for item in row] for row in data]
    return dataSet

# Majority voter:
# the most frequent value in the column wins.
def majorityCnt(classList):
    import operator
    classCounts = {}
    for value in classList:
        if value not in classCounts.keys():
            classCounts[value] = 0
        classCounts[value] += 1
    sortedClassCount = sorted(classCounts.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Pruning algorithm (pre-order traversal: root => left subtree => right subtree).
'''Algorithm steps
1. Starting from the root, recurse until both children of a node are leaves.
2. Compute the gain infoGain of such a parent node (both children are leaves).
3. If infoGain < miniGain, replace the parent with the leaf holding more samples.
4. Repeat 1, 2, 3 until the whole tree has been traversed.
'''
def prune(tree, miniGain, evaluationFunc=gini):
    print(u"current node:")
    print(str(tree))
    # If the left child is not a leaf, recurse into the left subtree.
    if tree.leftBranch.results is None:
        print(u"left child:")
        print(str(tree.leftBranch))
        prune(tree.leftBranch, miniGain, evaluationFunc)
    # If the right child is not a leaf, recurse into the right subtree.
    if tree.rightBranch.results is None:
        print(u"right child:")
        print(str(tree.rightBranch))
        prune(tree.rightBranch, miniGain, evaluationFunc)
    # Both children are leaves.
    if tree.leftBranch.results is not None and tree.rightBranch.results is not None:
        # Number of rows in the left leaf.
        leftLen = len(tree.leftBranch.data)
        # Number of rows in the right leaf.
        rightLen = len(tree.rightBranch.data)
        # Proportion of rows in the left leaf.
        leftProp = leftLen / (leftLen + rightLen)
        # Gain of this node's split (both children are leaves).
        infoGain = (evaluationFunc(tree.leftBranch.data + tree.rightBranch.data) -
                    leftProp * evaluationFunc(tree.leftBranch.data) - (1 - leftProp) * evaluationFunc(tree.rightBranch.data))
        # Gain below the given threshold: the leaves differ little from their
        # parent, so the split can be pruned away.
        if infoGain < miniGain:
            # Merge the data of the two leaves.
            dataSet = tree.leftBranch.data + tree.rightBranch.data
            # Extract the label column.
            classLabels = [example[-1] for example in dataSet]
            # Majority label of the merged data.
            keyLabel = majorityCnt(classLabels)
            # Check which leaf holds the majority label.
            if keyLabel in tree.leftBranch.results:
                # The left leaf replaces the parent.
                tree.data = tree.leftBranch.data
                tree.results = tree.leftBranch.results
                tree.summary = tree.leftBranch.summary
            else:
                # The right leaf replaces the parent.
                tree.data = tree.rightBranch.data
                tree.results = tree.rightBranch.results
                tree.summary = tree.rightBranch.summary
            tree.leftBranch = None
            tree.rightBranch = None
```
```python
'''
Created on Oct 14, 2010

@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="circle", fc="0.7")
arrow_args = dict(arrowstyle="<-")

# Count the leaf nodes of the tree.
def getNumLeafs(myTree):
    numLeafs = 0
    # Convert the dict keys to a list.
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # Test whether the child is a subtree: subtrees are dicts, leaves are strings.
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Depth of the tree.
def getTreeDepth(myTree):
    maxDepth = 0
    # Convert the dict keys to a list.
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # dicts are subtrees, everything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):  # the first key tells what feature was split on
    numLeafs = getNumLeafs(myTree)  # determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # subtree: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # leaf node: draw it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

# Draw the decision tree, variant 1.
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    # Width and height spacing (the -3/-2 offsets are manual layout tweaks).
    plotTree.totalW = float(getNumLeafs(inTree)) - 3
    plotTree.totalD = float(getTreeDepth(inTree)) - 2
    # plotTree.totalW = float(getNumLeafs(inTree))
    # plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.95, 1.0), '')
    plt.show()

# Draw the decision tree, variant 2.
def createPlot1(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # plotNode()/plotMidText() draw on createPlot.ax1, so it is set here as well.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # Width and height spacing (the -4.5/-3 offsets are manual layout tweaks).
    plotTree.totalW = float(getNumLeafs(inTree)) - 4.5
    plotTree.totalD = float(getTreeDepth(inTree)) - 3
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (1.0, 1.0), '')
    plt.show()

# Draw a sample decision node (rectangle) and leaf node (ellipse).
# def createPlot():
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()
#     createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

# thisTree = retrieveTree(0)
# createPlot(thisTree)
# createPlot()
# myTree = retrieveTree(0)
# numLeafs = getNumLeafs(myTree)
# treeDepth = getTreeDepth(myTree)
# print(u"number of leaf nodes: %d" % numLeafs)
# print(u"tree depth: %d" % treeDepth)
```
```python
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 15 14:16:59 2018

@author: weixw
"""
import myCart as mc

if __name__ == '__main__':
    import treePlotter as tp
    dataSet = mc.loadCSV("dataSet.csv")
    # Note: gini lives in the myCart module, hence mc.gini.
    myTree = mc.createTree(dataSet, evaluationFunc=mc.gini)
    print(u"myTree:%s" % myTree)
    # Draw the decision tree.
    print(u"drawing the decision tree:")
    tp.createPlot1(myTree)
    decisionTree = mc.buildDecisionTree(dataSet, evaluationFunc=mc.gini)
    testData = [5.9, 3, 4.2, 1.75]
    r = mc.classify(testData, decisionTree)
    print(u"classification result before pruning:")
    print(r)
    print()
    mc.prune(decisionTree, 0.4)
    r1 = mc.classify(testData, decisionTree)
    print(u"classification result after pruning:")
    print(r1)
```

 

Running Results

為什么我要再寫個createTree(dataSet, evaluationFunc=gini)函數,是因為繪制決策樹createPlot1(myTree)輸入參數需要是json結構數據。

 

Turning the generated decision tree into a visual figure makes it much more intuitive.

You can also print the information stored in the custom Tree objects; I have added print statements for this in the code.

The printed output follows; because of screen size I did not paste all of it. You can check it against the plotted decision tree, and the two views confirm each other and deepen understanding.

 

 

The classification test result without pruning is as follows (output screenshot not reproduced here):

 

The classification test result after pruning (output screenshot not reproduced here):

As you can see, {'versicolor': 47} has replaced its parent node serial:3 and become a new leaf node.

 

