關於CarEvaluation 數據集決策樹的可視化(附代碼)


2019-12-03 15:31:13

數據集類型為:

vhigh,vhigh,2,2,small,low,unacc
vhigh,vhigh,2,2,small,med,unacc
vhigh,vhigh,2,2,small,high,unacc
vhigh,vhigh,2,2,med,low,unacc
vhigh,vhigh,2,2,med,med,unacc
vhigh,vhigh,2,2,med,high,unacc
vhigh,vhigh,2,2,big,low,unacc
vhigh,vhigh,2,2,big,med,unacc
vhigh,vhigh,2,2,big,high,unacc
vhigh,vhigh,2,4,small,low,unacc
vhigh,vhigh,2,4,small,med,unacc
vhigh,vhigh,2,4,small,high,unacc
vhigh,vhigh,2,4,med,low,unacc

具體的就不一一列出了,需要原數據集的可以評論

參考了https://www.cnblogs.com/wsine/p/5180310.html

剪枝前

  5 from math import log
  6 import operator
  7 import treeplotter
  8 import pandas as pd
  9 import numpy as np
 10 
 11 def calcShannonEnt(dataSet):
 12     """
 13     輸入:數據集
 14     輸出:數據集的香農熵
 15     描述:計算給定數據集的香農熵
 16     """
 17     numEntries = len(dataSet)
 18     labelCounts = {}
 19     for featVec in dataSet:
 20         currentLabel = featVec[-1]
 21         if currentLabel not in labelCounts.keys():
 22             labelCounts[currentLabel] = 0
 23         labelCounts[currentLabel] += 1
 24     shannonEnt = 0.0
 25     for key in labelCounts:
 26         prob = float(labelCounts[key])/numEntries
 27         shannonEnt -= prob * log(prob, 2)
 28     return shannonEnt
 29 
 30 def splitDataSet(dataSet, axis, value):
 31     """
 32     輸入:數據集,選擇維度,選擇值
 33     輸出:划分數據集
 34     描述:按照給定特征划分數據集;去除選擇維度中等於選擇值的項
 35     """
 36     retDataSet = []
 37     for featVec in dataSet:
 38         if featVec[axis] == value:
 39             reduceFeatVec = featVec[:axis]
 40             reduceFeatVec.extend(featVec[axis+1:])
 41             retDataSet.append(reduceFeatVec)
 42     return retDataSet
 43 
 44 def chooseBestFeatureToSplit(dataSet):
 45     """
 46     輸入:數據集
 47     輸出:最好的划分維度
 48     描述:選擇最好的數據集划分維度
 49     """
 50     numFeatures = len(dataSet[0]) - 1
 51     baseEntropy = calcShannonEnt(dataSet)
 52     bestInfoGain = 0.0
 53     bestFeature = -1
 54     for i in range(numFeatures):
 55         featList = [example[i] for example in dataSet]
 56         uniqueVals = set(featList)
 57         newEntropy = 0.0
 58         for value in uniqueVals:
 59             subDataSet = splitDataSet(dataSet, i, value)
 60             prob = len(subDataSet)/float(len(dataSet))
 61             newEntropy += prob * calcShannonEnt(subDataSet)
 62         infoGain = baseEntropy - newEntropy
 63         if (infoGain > bestInfoGain):
 64             bestInfoGain = infoGain
 65             bestFeature = i
 66     return bestFeature
 67 
 68 def majorityCnt(classList):
 69     """
 70     輸入:分類類別列表
 71     輸出:子節點的分類
 72     描述:數據集已經處理了所有屬性,但是類標簽依然不是唯一的,
 73           采用多數判決的方法決定該子節點的分類
 74     """
 75     classCount = {}
 76     for vote in classList:
 77         if vote not in classCount.keys():
 78             classCount[vote] = 0
 79         classCount[vote] += 1
 80     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)
 81     return sortedClassCount[0][0]
 82 
 83 def createTree(dataSet, labels):
 84     """
 85     輸入:數據集,特征標簽
 86     輸出:決策樹
 87     描述:遞歸構建決策樹,利用上述的函數
 88     """
 89     classList = [example[-1] for example in dataSet]
 90     if classList.count(classList[0]) == len(classList):
 91         # 類別完全相同,停止划分
 92         return classList[0]
 93     if len(dataSet[0]) == 1:
 94         # 遍歷完所有特征時返回出現次數最多的
 95         return majorityCnt(classList)
 96     bestFeat = chooseBestFeatureToSplit(dataSet)
 97     bestFeatLabel = labels[bestFeat]
 98     myTree = {bestFeatLabel:{}}
 99     del(labels[bestFeat])
100     # 得到列表包括節點所有的屬性值
101     featValues = [example[bestFeat] for example in dataSet]
102     uniqueVals = set(featValues)
103     for value in uniqueVals:
104         subLabels = labels[:]
105         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
106     return myTree
107 
def classify(inputTree, featLabels, testVec):
    """Walk the decision tree and return the predicted class for testVec.

    inputTree: nested dict produced by createTree.
    featLabels: feature names in the same order as testVec's columns.
    testVec: one record of feature values (no class label).

    Returns None when testVec's value for this node's feature matches no
    branch of the tree — the original code left `classLabel` unbound in
    that case and raised UnboundLocalError.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # default when the test value was never seen in training
    for key in secondDict:
        if testVec[featIndex] == key:
            subtree = secondDict[key]
            if isinstance(subtree, dict):
                classLabel = classify(subtree, featLabels, testVec)
            else:
                classLabel = subtree
    return classLabel
124 
def classifyAll(inputTree, featLabels, testDataSet):
    """Classify every record in testDataSet; return the list of predictions."""
    return [classify(inputTree, featLabels, vec) for vec in testDataSet]
135 
def storeTree(inputTree, filename):
    """Pickle the decision tree to the file at `filename`.

    Uses a ``with`` block so the file handle is closed even when
    pickle.dump raises — the original open/close pair leaked the handle
    on error.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
146 
def grabTree(filename):
    """Unpickle and return a decision tree from the file at `filename`.

    Uses a ``with`` block so the handle is closed deterministically —
    the original opened the file and never closed it.
    NOTE(review): pickle.load must only be used on trusted files; it can
    execute arbitrary code when the input is attacker-controlled.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
156 
def createDataSet():
    """Load car.csv and return (records, feature names).

    Every occurrence of the string '5more' is replaced by the integer 6
    before the frame is converted to a list of row lists. The loaded
    records are printed, matching the original's debug output.
    """
    frame = pd.read_csv("car.csv")
    cleaned = frame.replace('5more', 6)
    dataSet = np.array(cleaned.values).tolist()
    print(dataSet)

    labels = ['buying', 'maint', 'doors', 'persons', 'lug_boot', 'safety']
    return dataSet, labels
166 
167 
def main():
    """Build the decision tree from car.csv, print it, and plot it."""
    dataSet, labels = createDataSet()
    # createTree mutates its labels argument, so pass a copy.
    desicionTree = createTree(dataSet, labels[:])
    #storeTree(desicionTree, 'classifierStorage.txt')
    #desicionTree = grabTree('classifierStorage.txt')
    print('desicionTree:\n', desicionTree)
    treeplotter.createPlot(desicionTree)
176 
177 
# Script entry point: only run when executed directly, not on import.
if __name__ == '__main__':
    main()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM