2019-12-03 15:31:13
數據集類型為:
vhigh,vhigh,2,2,small,low,unacc
vhigh,vhigh,2,2,small,med,unacc
vhigh,vhigh,2,2,small,high,unacc
vhigh,vhigh,2,2,med,low,unacc
vhigh,vhigh,2,2,med,med,unacc
vhigh,vhigh,2,2,med,high,unacc
vhigh,vhigh,2,2,big,low,unacc
vhigh,vhigh,2,2,big,med,unacc
vhigh,vhigh,2,2,big,high,unacc
vhigh,vhigh,2,4,small,low,unacc
vhigh,vhigh,2,4,small,med,unacc
vhigh,vhigh,2,4,small,high,unacc
vhigh,vhigh,2,4,med,low,unacc
具體的就不一一列出了,需要原數據集的可以評論
參考了https://www.cnblogs.com/wsine/p/5180310.html
剪枝前
from math import log
import operator

import pandas as pd
import numpy as np


def calcShannonEnt(dataSet):
    """
    Input:  dataset; each row is a list whose last element is the class label
    Output: Shannon entropy of the dataset (float)
    Description: count label frequencies, then H = -sum(p * log2(p)).
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """
    Input:  dataset, feature index (axis), feature value
    Output: the subset of rows whose feature `axis` equals `value`,
            with that feature column removed
    Description: partition step used by the ID3 split search.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Drop the chosen column so it is not reused deeper in the tree.
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """
    Input:  dataset
    Output: index of the feature with the highest information gain,
            or -1 if no split improves on zero gain
    Description: classic ID3 criterion — gain = H(D) - sum(|Dv|/|D| * H(Dv)).
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """
    Input:  list of class labels
    Output: the most frequent label
    Description: majority vote used when all features are exhausted but
    the labels at a leaf are still mixed.
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # BUGFIX: original used Python-2-only dict.iteritems() and the invalid
    # keyword `reversed=True`; Python 3 sorted() takes `reverse=True`.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """
    Input:  dataset, list of feature names (mutated: the chosen feature
            name is deleted at each level — pass a copy if you need it)
    Output: decision tree as nested dicts {feature: {value: subtree|label}}
    Description: recursively build an ID3 tree using the helpers above.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        # All rows share one class: stop splitting.
        return classList[0]
    if len(dataSet[0]) == 1:
        # All features consumed: fall back to majority vote.
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    # Every distinct value of the chosen feature becomes a branch.
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so sibling branches see the same labels
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """
    Input:  decision tree, list of feature names, one test vector
    Output: predicted class label, or None when the vector's value has
            no matching branch in the tree
    Description: walk the nested-dict tree until a leaf is reached.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    # BUGFIX: initialize so an unmatched feature value returns None instead
    # of raising UnboundLocalError.
    classLabel = None
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


def classifyAll(inputTree, featLabels, testDataSet):
    """
    Input:  decision tree, list of feature names, list of test vectors
    Output: list of predicted class labels (one per test vector)
    """
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll


def storeTree(inputTree, filename):
    """
    Input:  decision tree, output file path
    Output: none
    Description: pickle the tree to a file.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """
    Input:  file path
    Output: decision tree
    Description: unpickle a tree previously saved by storeTree.
    NOTE: pickle.load must only be used on trusted files.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


def createDataSet():
    """
    Input:  none (reads "car.csv" from the working directory)
    Output: (dataset rows as lists, feature-name list)
    Description: load the UCI car-evaluation data; '5more' is mapped to 6
    so the column stays comparable.
    """
    data = pd.read_csv("car.csv")
    train_data1 = (data.replace('5more', 6)).values
    train_data = np.array(train_data1)  # np.ndarray
    dataSet = train_data.tolist()  # list of lists
    print(dataSet)

    labels = ['buying', 'maint', 'doors', 'persons', 'lug_boot', 'safety']
    return dataSet, labels


def main():
    # Deferred import: treeplotter is only needed for plotting, so the
    # module stays importable (for classification) without it installed.
    import treeplotter
    dataSet, labels = createDataSet()
    labels_tmp = labels[:]  # copy — createTree mutates its labels argument
    desicionTree = createTree(dataSet, labels_tmp)
    # storeTree(desicionTree, 'classifierStorage.txt')
    # desicionTree = grabTree('classifierStorage.txt')
    print('desicionTree:\n', desicionTree)
    treeplotter.createPlot(desicionTree)


if __name__ == '__main__':
    main()
