In the previous post we implemented an ID3 decision tree with the sklearn library. In this post we implement the ID3 algorithm using only Python's built-in facilities (plus matplotlib for plotting).
I. Background knowledge used in the code
1. For convenient plotting, a separate treePlotter module (source given below) is used to draw the tree. It is simple to use: call its createPlot function with a tree object and it draws the corresponding figure.
2. How to define a tree structure in Python
A tree can be represented with Python's built-in dict type. For example, the following code defines a root node with a left and a right child:
rootNode = {'rootNode': {}}
leftNode = {'leftNode': {'yes': 'yes'}}
rightNode = {'rightNode': {'no': 'no'}}
rootNode['rootNode']['left'] = leftNode
rootNode['rootNode']['right'] = rightNode
treePlotter.createPlot(rootNode)
Calling treePlotter.createPlot on rootNode draws the tree shown below.
3. Recursion
To collect the child nodes under every node, recursion is used; the basic idea is the same as traversing a binary tree. We will implement a binary tree in Python in a later post, so it is not covered here; a minimal sketch of the traversal pattern is shown below.
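As an illustration of that pattern (a sketch only, not part of the final program), the function below counts the leaves of a tree stored as nested dicts, in the same way getNumLeafs in treePlotter does:

# Illustrative sketch: recursively count the leaf nodes of a nested-dict tree.
def countLeaves(tree):
    root = list(tree.keys())[0]        # the single key at this level
    count = 0
    for child in tree[root].values():
        if isinstance(child, dict):    # internal node: recurse into the subtree
            count += countLeaves(child)
        else:                          # anything that is not a dict is a leaf
            count += 1
    return count

print(countLeaves({'age': {'youth': 'no', 'senior': {'credit': {'fair': 'yes'}}}}))  # prints 2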
4. In addition, you should be familiar with Python's common data types and their operations, such as dictionaries, lists, and sets; a couple of the idioms used later are sketched below.
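For instance, the program repeatedly pulls one column out of the dataset with a list comprehension, counts class labels in a dict, and uses a set to get a column's distinct values. A quick illustration (the example data here is made up for this snippet):

# Illustrative only: extract column 0 of a tiny dataset and list its distinct values.
dataset = [['youth', 'no'], ['senior', 'yes'], ['youth', 'yes']]
ageColumn = [row[0] for row in dataset]   # list comprehension over the rows
uniqueAges = set(ageColumn)               # distinct values in that column
print(ageColumn)    # ['youth', 'senior', 'youth']
print(uniqueAges)   # {'youth', 'senior'} (set order may vary)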
II. Main program flow
In outline: read the CSV training set, compute the entropy of the whole set, pick the feature with the largest information gain as the current node, split the data on that feature and recurse until every branch is pure or all features are used, save and reload the tree with pickle, plot it with treePlotter, and finally classify test samples.
III. Test dataset
| age | income | student | credit_rating | class_buys_computer |
| --- | --- | --- | --- | --- |
| youth | high | no | fair | no |
| youth | high | no | excellent | no |
| middle_aged | high | no | fair | yes |
| senior | medium | no | fair | yes |
| senior | low | yes | fair | yes |
| senior | low | yes | excellent | no |
| middle_aged | low | yes | excellent | yes |
| youth | medium | no | fair | no |
| youth | low | yes | fair | yes |
| senior | medium | yes | fair | yes |
| youth | medium | yes | excellent | yes |
| middle_aged | medium | no | excellent | yes |
| middle_aged | high | yes | fair | yes |
| senior | medium | no | excellent | no |
IV. Program code
1. Computing the entropy and information gain of the training set
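For reference, these are the standard ID3 quantities that calcShannonEnt and chooseBestFeatureToSplit below compute, where D is the dataset, p_k the proportion of samples in class k, and D_v the subset of samples where feature A takes value v:

H(D) = -\sum_{k} p_k \log_2 p_k

\mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{Values}(A)} \frac{|D_v|}{|D|}\, H(D_v)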
# Choose the feature with the largest information gain (used as the splitting node)
def chooseBestFeatureToSplit(dataset, headerList):
    # initial values
    bestInfoGain = 0.0
    bestFeature = 0
    # number of feature columns (the last column is the class label)
    numFeatures = len(dataset[0]) - 1
    # entropy of the whole training set
    baseShannonEnt = calcShannonEnt(dataset)
    print("total's shannonEnt = %f" % baseShannonEnt)
    # compute the information gain of each feature column
    for i in range(numFeatures):
        # all values this feature takes
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        # conditional entropy of the dataset given this feature
        newShannonEnt = 0.0
        for value in uniqueVals:
            # subset of samples with this feature value
            subDataset = splitDataSet(dataset, i, value)
            # probability of the value times the entropy of its subset
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability * calcShannonEnt(subDataset)
        infoGain = baseShannonEnt - newShannonEnt
        print("%s's infoGain = %f - %f = %f" % (headerList[i], baseShannonEnt, newShannonEnt, infoGain))
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
The result matches the calculation from the previous post: the age feature has the largest information gain, so it becomes the root node.
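For this 14-sample set (9 "yes", 5 "no") the printed values should come out roughly as follows; these are the standard figures for this dataset, listed here only as a sanity check:

H(D) = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940

\mathrm{Gain}(age) \approx 0.246,\quad \mathrm{Gain}(income) \approx 0.029,\quad \mathrm{Gain}(student) \approx 0.151,\quad \mathrm{Gain}(credit\_rating) \approx 0.048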
2. Full program source
treePlotter.py
import matplotlib.pyplot as plt

# Node and arrow styles used when drawing the decision tree
descisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # nodeTxt: text to display; centerPt: centre of the text box;
    # parentPt: point the arrow starts from; nodeType: box style for this node
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    # list() is needed here: in Python 3 myTree.keys() returns a dict_keys view,
    # not a list, so indexing it directly raises an error (same fix in several places)
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # total width = number of leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # total height = tree depth
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # cntrPt: centre of this node's text; parentPt: position of the parent node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, descisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va='center', ha='center', rotation=30)
decision_tree_ID3.py
# imports
import csv
import math
import operator
import pickle

import treePlotter


# Read the training set from a CSV file
def readDataset(file_path, file_mode):
    allElectronicsData = open(file_path, file_mode)
    reader = csv.reader(allElectronicsData)
    # first row: feature names
    headers = next(reader)
    # remaining rows: training samples
    dataset = []
    for row in reader:
        dataset.append(row)
    return headers, dataset


# Shannon entropy of a dataset (the class label is the last column)
def calcShannonEnt(dataset):
    shannonEnt = 0.0
    labelList = {}
    for vec_now in dataset:
        labelValue = vec_now[-1]
        if labelValue not in labelList.keys():
            labelList[labelValue] = 0
        labelList[labelValue] += 1
    for labelKey in labelList:
        probability = float(labelList[labelKey]) / len(dataset)
        shannonEnt -= probability * math.log(probability, 2)
    return shannonEnt


# Return the samples whose value in column feature_seq equals value, with that column removed
def splitDataSet(dataset, feature_seq, value):
    new_dataset = []
    for vec_row in dataset:
        if vec_row[feature_seq] == value:
            temp_vec = vec_row[:feature_seq]
            temp_vec.extend(vec_row[feature_seq + 1:])
            new_dataset.append(temp_vec)
    return new_dataset


# Choose the feature with the largest information gain
def chooseBestFeatureToSplit(dataset, headerList):
    bestInfoGain = 0.0
    bestFeature = 0
    # number of feature columns (the last column is the class label)
    numFeatures = len(dataset[0]) - 1
    # entropy of the whole training set
    baseShannonEnt = calcShannonEnt(dataset)
    # print("total's shannonEnt = %f" % baseShannonEnt)
    # compute the information gain of each feature column
    for i in range(numFeatures):
        # all values this feature takes
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        # conditional entropy of the dataset given this feature
        newShannonEnt = 0.0
        for value in uniqueVals:
            subDataset = splitDataSet(dataset, i, value)
            # probability of the value times the entropy of its subset
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability * calcShannonEnt(subDataset)
        infoGain = baseShannonEnt - newShannonEnt
        # print("%s's infoGain = %f - %f = %f" % (headerList[i], baseShannonEnt, newShannonEnt, infoGain))
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# Pick a label by majority vote
def majorityCnt(classList):
    classcount = {}
    for cl in classList:
        if cl not in classcount.keys():
            classcount[cl] = 0
        classcount[cl] += 1
    sortedClassCount = sorted(classcount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# Build the decision tree
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # 1. all samples belong to the same class: return that class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 2. all features have been used: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # 3. otherwise split on the best feature
    bestFeature = chooseBestFeatureToSplit(dataSet, labels)
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel: {}}
    del labels[bestFeature]
    featureValues = [example[bestFeature] for example in dataSet]
    uniqueVals = set(featureValues)
    # recurse into every branch to build the whole tree
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
    return myTree


# Classify a single sample by walking down the tree
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # stays None if the sample's value never appeared in training
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


# Classify a list of samples
def classifyAll(inputTree, featLabels, testDataSet):
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll


# Save the tree to disk with pickle
def storeTree(inputTree, filename):
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


# Load a tree saved by storeTree
def grabTree(filename):
    fr = open(filename, 'rb')
    return pickle.load(fr)


def main():
    # read the training set
    labels, dataSet = readDataset(file_path=r'D:\test.csv', file_mode='r')
    labels_tmp = labels[:]  # copy, because createTree modifies the label list
    desicionTree = createTree(dataSet, labels_tmp)
    storeTree(desicionTree, 'classifierStorage.txt')
    desicionTree = grabTree('classifierStorage.txt')
    treePlotter.createPlot(desicionTree)
    testSet = [['youth', 'high', 'no', 'fair', 'no']]
    print('classifyResult:\n', classifyAll(desicionTree, labels, testSet))


if __name__ == '__main__':
    main()
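If you do not have a CSV file at D:\test.csv, one simple smoke test (a sketch only, assuming the module is saved as decision_tree_ID3.py as above) is to paste the table from section III directly into a script; the learned tree should have the shape shown in the trailing comment:

# Assumed inline copy of the training table from section III, for testing without the CSV file.
from decision_tree_ID3 import createTree

labels = ['age', 'income', 'student', 'credit_rating', 'class_buys_computer']
dataSet = [
    ['youth', 'high', 'no', 'fair', 'no'],
    ['youth', 'high', 'no', 'excellent', 'no'],
    ['middle_aged', 'high', 'no', 'fair', 'yes'],
    ['senior', 'medium', 'no', 'fair', 'yes'],
    ['senior', 'low', 'yes', 'fair', 'yes'],
    ['senior', 'low', 'yes', 'excellent', 'no'],
    ['middle_aged', 'low', 'yes', 'excellent', 'yes'],
    ['youth', 'medium', 'no', 'fair', 'no'],
    ['youth', 'low', 'yes', 'fair', 'yes'],
    ['senior', 'medium', 'yes', 'fair', 'yes'],
    ['youth', 'medium', 'yes', 'excellent', 'yes'],
    ['middle_aged', 'medium', 'no', 'excellent', 'yes'],
    ['middle_aged', 'high', 'yes', 'fair', 'yes'],
    ['senior', 'medium', 'no', 'excellent', 'no'],
]
tree = createTree(dataSet, labels[:])   # pass a copy: createTree deletes used labels
print(tree)
# Expected shape (branch order may vary):
# {'age': {'youth': {'student': {'no': 'no', 'yes': 'yes'}},
#          'middle_aged': 'yes',
#          'senior': {'credit_rating': {'fair': 'yes', 'excellent': 'no'}}}}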
V. Test results and conclusions
From the information-gain formula above we can see that the information-gain criterion is biased toward attributes that take many distinct values.
Now suppose we also treat a sample "ID" column as a candidate splitting attribute. Its information gain comes out equal to the entropy of the whole training set (about 0.940 for this data), higher than that of any real feature.
Because every sample's ID is unique, the conditional entropy is 0: each branch contains exactly one sample and is perfectly pure. In other words, once a test sample's ID is known the other features are never used, and the resulting tree clearly has no ability to generalize (a small sketch of this effect follows).
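A minimal sketch of this effect, assuming the readDataset and chooseBestFeatureToSplit functions from decision_tree_ID3.py above and the same CSV path:

# Hypothetical demo: prepend a unique 'id' column to every sample and let
# chooseBestFeatureToSplit pick the best attribute. 'id' wins, because each
# id value isolates a single sample, so its conditional entropy is 0 and its
# information gain equals the entropy of the whole training set.
from decision_tree_ID3 import readDataset, chooseBestFeatureToSplit

headers, dataset = readDataset(file_path=r'D:\test.csv', file_mode='r')
idHeaders = ['id'] + headers
idDataset = [[str(i)] + row for i, row in enumerate(dataset)]
best = chooseBestFeatureToSplit(idDataset, idHeaders)
print(idHeaders[best])   # expected output: id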
References:
http://www.cnblogs.com/wsine/p/5180310.html
https://zhuanlan.zhihu.com/p/26760551