day-8 Implementing the ID3 Decision Tree Algorithm with Python's Standard Library


 

  The previous day we implemented an ID3 decision tree program on top of the sklearn library; this article implements the ID3 decision tree algorithm using Python's standard library instead.

 I. Basic knowledge involved in the code

  1. For convenient plotting, a separate treePlotter module (built on matplotlib; its source is given in section IV) is used to draw the tree. It is simple to use: call its createPlot interface and pass in a tree-structured object, and the corresponding figure is drawn.

  2. How to define a tree-structured object in Python

    Python's built-in dict type can be used to define a tree object. In the code below, we define a root node and attach a left and a right child to it:

    rootNode = {'rootNode': {}}
    leftNode = {'leftNode': {'yes':'yes'}}
    rightNode = {'rightNode': {'no':'no'}}
    rootNode['rootNode']['left'] = leftNode
    rootNode['rootNode']['right'] = rightNode
    treePlotter.createPlot(rootNode)

    Calling the treePlotter module then draws the corresponding tree figure.

    

  3. Recursion

    To find the children of every node, recursion is used; the basic idea is the same as traversing a binary tree (a Python binary-tree implementation will be covered in a later post, so it is not described further here). A minimal traversal sketch is given after this list.

  4. In addition, you should be familiar with Python's common data types and their operations, such as dicts, lists, and sets.
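    A minimal sketch of such a recursive traversal (the helper collectLeaves below is illustrative and not part of the program that follows; getNumLeafs and getTreeDepth in treePlotter.py use the same pattern of recursing whenever a value is itself a dict):

    # walk a nested-dict tree: a dict value is an internal node, anything else is a leaf
    def collectLeaves(tree):
        leaves = []
        for key, child in tree.items():
            if isinstance(child, dict):
                leaves.extend(collectLeaves(child))   # recurse into the subtree
            else:
                leaves.append(child)                  # leaf: a class label
        return leaves

    print(collectLeaves(rootNode))   # with the rootNode built above: ['yes', 'no']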

II. Main program flow

  In outline, main() in the source below reads the CSV dataset, computes the entropy of the whole set and the information gain of every feature, splits on the feature with the largest gain, builds the tree recursively on each resulting subset, plots the finished tree with treePlotter, and finally classifies new samples by walking the tree.

 

 

 

III. Test dataset

age          income   student  credit_rating  class_buys_computer
youth        high     no       fair           no
youth        high     no       excellent      no
middle_aged  high     no       fair           yes
senior       medium   no       fair           yes
senior       low      yes      fair           yes
senior       low      yes      excellent      no
middle_aged  low      yes      excellent      yes
youth        medium   no       fair           no
youth        low      yes      fair           yes
senior       medium   yes      fair           yes
youth        medium   yes      excellent      yes
middle_aged  medium   no       excellent      yes
middle_aged  high     yes      fair           yes
senior       medium   no       excellent      no
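  The program in section IV reads this table from a CSV file; main() uses the path r'D:\test.csv'. As a convenience, the following sketch writes the table to that file with the csv module (the path is only an assumption taken from main() and can be changed freely):

import csv

# write the test dataset (header row first) to the CSV file that main() expects
rows = [
    ['age', 'income', 'student', 'credit_rating', 'class_buys_computer'],
    ['youth', 'high', 'no', 'fair', 'no'],
    ['youth', 'high', 'no', 'excellent', 'no'],
    ['middle_aged', 'high', 'no', 'fair', 'yes'],
    ['senior', 'medium', 'no', 'fair', 'yes'],
    ['senior', 'low', 'yes', 'fair', 'yes'],
    ['senior', 'low', 'yes', 'excellent', 'no'],
    ['middle_aged', 'low', 'yes', 'excellent', 'yes'],
    ['youth', 'medium', 'no', 'fair', 'no'],
    ['youth', 'low', 'yes', 'fair', 'yes'],
    ['senior', 'medium', 'yes', 'fair', 'yes'],
    ['youth', 'medium', 'yes', 'excellent', 'yes'],
    ['middle_aged', 'medium', 'no', 'excellent', 'yes'],
    ['middle_aged', 'high', 'yes', 'fair', 'yes'],
    ['senior', 'medium', 'no', 'excellent', 'no'],
]
with open(r'D:\test.csv', 'w', newline='') as f:
    csv.writer(f).writerows(rows)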

IV. Program code

         1. Computing the entropy of the dataset and the information gain of each feature
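  For reference, the code below implements the standard ID3 quantities (these are also the formulas the conclusion in section V refers to): the Shannon entropy of a dataset $D$ and the information gain of splitting $D$ on a feature $a$,

  $$H(D) = -\sum_{k} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, a) = H(D) - \sum_{v \in \mathrm{Values}(a)} \frac{|D^v|}{|D|}\, H(D^v),$$

  where $p_k$ is the proportion of samples in $D$ with class label $k$ and $D^v$ is the subset of $D$ whose feature $a$ takes the value $v$. calcShannonEnt computes $H$, and the inner loop below accumulates the weighted term that is subtracted from the base entropy.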

# choose the best split feature (the root of the current subtree)
def chooseBestFeatureToSplit(dataset,headerList):
    # initial values
    bestInfoGainRate = 0.0
    bestFeature = 0
    # number of feature columns (the last column is the class label)
    numFeatures = len(dataset[0]) -1
    # entropy of the whole dataset
    baseShnnonEnt = calcShannonEnt(dataset)
    print("total's shannonEnt = %f" % (baseShnnonEnt))
    # iterate over every feature column and compute its information gain
    for i in range(numFeatures):
        # all values appearing in this column
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        # weighted entropy after splitting on this column
        newShannonEnt = 0.0
        for value in uniqueVals:
            # subset of samples with this value, its probability and its entropy
            subDataset = splitDataSet(dataset,i,value)
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability*calcShannonEnt(subDataset)
        infoGainRate = baseShnnonEnt - newShannonEnt
        print("%s's infoGainRate = %f - %f = %f"%(headerList[i],baseShnnonEnt,newShannonEnt,infoGainRate))
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature
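  A quick way to check which feature is chosen for the root (a sketch: it assumes test.csv contains the table from section III and that readDataset, calcShannonEnt and splitDataSet from the full listing below are in scope):

headers, dataset = readDataset(r'D:\test.csv', 'r')
best = chooseBestFeatureToSplit(dataset, headers)
print(headers[best])   # prints 'age' for this dataset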

  The result matches the previous day's calculation: for this dataset the entropy of the whole set is about 0.940, and age has the largest information gain (about 0.246), so age is chosen as the root node.

        

         2. Full program source

         treePlotter.py        

import matplotlib.pyplot as plt

# node and arrow styles used when drawing the decision tree
descisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # nodeTxt: text to display; centerPt: where the text box is placed;
    # parentPt: the point the arrow starts from; nodeType: box style (decision node or leaf)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                             xytext=centerPt, textcoords='axes fraction',
                              va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
# count the number of leaf nodes in the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# compute the depth (number of decision levels) of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]     # note: in Python 3, myTree.keys() returns a dict_keys view, not a list, so it is wrapped in list() here (and in several other places)
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])   # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))     # global width  = number of leaves
    plotTree.totalD = float(getTreeDepth(inTree))    # global height = tree depth
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # cntrPt is the centre point of this node's text box; parentPt is the point its branch starts from
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, descisionNode)
    seconDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in seconDict.keys():
        if type(seconDict[key]).__name__ == 'dict':
            plotTree(seconDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

# write the branch label (the feature value) midway between parent and child nodes
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va='center', ha='center', rotation=30)
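  For a quick standalone check of the plotting module, passing in any nested dict of the shape described in section I should work; for example (a tiny hand-built tree, not part of the program below):

import treePlotter

# 'student' is the decision node; its 'no'/'yes' branches lead to the leaf labels 'no' and 'yes'
tinyTree = {'student': {'no': 'no', 'yes': 'yes'}}
treePlotter.createPlot(tinyTree)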

         decision_tree_ID3.py

# imports
import csv
import math
import operator
import treePlotter


# read the dataset from a CSV file
def readDataset(file_path,file_mode):
    with open(file_path, file_mode) as allElectronicsData:
        reader = csv.reader(allElectronicsData)
        # the first row holds the feature names
        headers = next(reader)
        # the remaining rows are the samples
        dataset = []
        for row in reader:
            dataset.append(row)
    return headers,dataset

# compute the Shannon entropy of a dataset (the class label is the last column)
def calcShannonEnt(dataset):
    shannonEnt = 0.0
    labelList = {}
    # count how many samples belong to each class label
    for vec_now in dataset:
        labelValue = vec_now[-1]
        if labelValue not in labelList.keys():
            labelList[labelValue] = 0
        labelList[labelValue] += 1
    # accumulate -p * log2(p) over the class labels
    for labelKey in labelList:
        probability = float(labelList[labelKey]) / len(dataset)
        shannonEnt -= probability*math.log(probability,2)
    return shannonEnt

# return the subset of samples whose feature column feature_seq equals value,
# with that column removed
def splitDataSet(dataset,feature_seq,value):
    new_dataset = []
    for vec_row in dataset:
        feature_Value = vec_row[feature_seq]
        if feature_Value == value:
            # copy the row without the column that was split on
            temp_vec = vec_row[:feature_seq]
            temp_vec.extend(vec_row[feature_seq+1:])
            new_dataset.append(temp_vec)
    return new_dataset

# choose the feature with the largest information gain for the current split
def chooseBestFeatureToSplit(dataset,headerList):
    # initial values
    bestInfoGainRate = 0.0
    bestFeature = 0
    # number of feature columns (the last column is the class label)
    numFeatures = len(dataset[0]) -1
    # entropy of the whole dataset
    baseShnnonEnt = calcShannonEnt(dataset)
    #print("total's shannonEnt = %f" % (baseShnnonEnt))
    # iterate over every feature column and compute its information gain
    for i in range(numFeatures):
        # all values appearing in this column
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        # weighted entropy after splitting on this column
        newShannonEnt = 0.0
        for value in uniqueVals:
            # subset of samples with this value, its probability and its entropy
            subDataset = splitDataSet(dataset,i,value)
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability*calcShannonEnt(subDataset)
        infoGainRate = baseShnnonEnt - newShannonEnt
        #print("%s's infoGainRate = %f - %f = %f"%(headerList[i],baseShnnonEnt,newShannonEnt,infoGainRate))
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature

# decide the label by majority vote (the most frequent class label wins)
def majorityCnt(classList):
    classcount = {}
    for cl in classList:
        if cl not in classcount.keys():
            classcount[cl] = 0
        classcount[cl] += 1
    sortedClassCount = sorted(classcount.items(),key = operator.itemgetter(1),reverse= True)
    return sortedClassCount[0][0]

# build the decision tree recursively as nested dicts
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # 1. if all samples share the same class label, return that label directly
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 2. if all features have been used up, return the most frequent label by majority vote
    if len(dataSet[0])  == 1:
        return majorityCnt(classList)
    # 3. otherwise, pick the best feature for the current split
    bestFeature = chooseBestFeatureToSplit(dataSet,labels)
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel: {}}
    del (labels[bestFeature])
    featurValues = [example[bestFeature] for example in dataSet]
    uniqueVals = set(featurValues)
    # recurse on every value of the chosen feature to build its subtree
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
    return myTree

# classify a single sample by walking the tree from the root
def classify(inputTree, featLabels, testVec):
    classLabel = None   # stays None if the sample's feature value was never seen in training
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

# classify every sample in a test set
def classifyAll(inputTree, featLabels, testDataSet):
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll

# persist the tree to disk with pickle
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

# load a previously stored tree
def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

def main():
    # read the dataset
    labels, dataSet = readDataset(file_path=r'D:\test.csv', file_mode='r')
    labels_tmp = labels[:]  # copy, because createTree mutates the labels list
    desicionTree = createTree(dataSet, labels_tmp)
    storeTree(desicionTree, 'classifierStorage.txt')
    desicionTree = grabTree('classifierStorage.txt')
    treePlotter.createPlot(desicionTree)
    testSet = [['youth', 'high', 'no', 'fair', 'no']]
    print('classifyResult:\n', classifyAll(desicionTree, labels, testSet))

if __name__ == '__main__':
    main()

V. Test results and conclusions
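  The plotted figure is not reproduced here. Structurally, the tree learned from the dataset in section III should correspond to the nested dict below (a sketch of the expected result; the order of sibling branches may vary between runs, since feature values are iterated as a set), and the test vector in main() is classified as 'no':

{'age': {'middle_aged': 'yes',
         'youth': {'student': {'no': 'no', 'yes': 'yes'}},
         'senior': {'credit_rating': {'fair': 'yes', 'excellent': 'no'}}}}
# classifyResult:
#  ['no']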

 

  From the formula for information gain given in section IV, we can see that the information-gain criterion is in fact biased toward attributes that can take many distinct values.
  Now suppose we also treated the sample ID in the dataset as a candidate splitting attribute; the information gain of the ID comes out to 0.998.
  Because every sample's ID is unique, the conditional entropy becomes 0 (each resulting node contains only one class, so purity is as high as it gets). In other words, for a new sample you would only need to know its ID and all other features would be irrelevant, so a tree built this way clearly has no ability to generalize.
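  In formula terms, this is a one-line check of the argument above: with a unique ID, every subset $D^v$ contains exactly one sample, so $H(D^v) = 0$ for all $v$, and

  $$\mathrm{Gain}(D, \mathrm{ID}) = H(D) - \sum_{v} \frac{1}{|D|} \cdot 0 = H(D),$$

  i.e. splitting on the ID always attains the maximum possible information gain, which is exactly why the information-gain criterion favours such attributes.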

 

  References:

  http://www.cnblogs.com/wsine/p/5180310.html

  https://zhuanlan.zhihu.com/p/26760551

 

