決策樹詳解,從熵說起


  熵,一個神奇的工具,用來衡量數據集信息量的不確定性。

  首先,我們先來了解一個指標,信息量。對於任意一個隨機變量X,樣本空間為{X1,X2,...,Xn},樣本空間可以這么理解,也就是隨機變量X所有的可能取值。如果在ML領域內,我們可以把Xi當做X所屬的某一個類。對於任意的樣本Xi(類Xi),樣本Xi的信息量也就是l(Xi) = -log(p(Xi))。由於p(Xi)是為樣本Xi的概率,也可以說是類Xi的概率,那么l(Xi)的取值范圍為(+∞,0]。也就是X的的概率越小,其含有的信息量越大,反之亦然。這也不難理解,Xi的發生的概率越大,我們對他的了解越深,那么他所含有的信息量就越小。如果隨機變量X是常量,那么我們從任意一個Xi都可以獲取樣本空間的值,那么我們可以認為X沒有任何信息量,他的信息量為0。如果說,我們要把隨機變量X樣本空間都了解完,才能獲得X的信息,那么我們可以認為X的信息量“無窮大”,其取值為log(2,n)。

  緊接着,我們就提出了隨機變量X的信息熵,也就是信息量的期望,H(X) = -∑ni=1p(Xi)log2(p(Xi))=∑xp(x)log(p(x)),從離散的角度得出的公式。他有三個特性:

  • 單調性,即發生概率越高的事件,其所攜帶的信息熵越低。極端案例就是“太陽從東方升起”,因為為確定事件,所以不攜帶任何信息量。從信息論的角度,認為這句話沒有消除任何不確定性。也就是樣本空間的p(Xi)均為1。
  • 非負性,即信息熵不能為負。這個很好理解,因為負的信息,即你得知了某個信息后,卻增加了不確定性是不合邏輯的。
  • 累加性,即多隨機事件同時發生存在的總不確定性的量度是可以表示為各事件不確定性的量度的和。

  有了熵這個基礎,那么我們就要考慮決策樹是怎么生成的了。對於隨機變量X的樣本個數為n的空間,每個樣本都有若干個相同的特征,假設有k個。對於任意一個特征,我們可以對其進行划分,假設含有性別變量,那么切分后,性別特征消失,變為確定的值,那么隨機變量X信息的不確定性減少。以此類推,直到達到我們想要的結果結束,這樣就生成了若干棵樹,每棵樹的不確定性降低。如果我們在此過程中進行限制,每次減少的不確定性最大,那么這樣一步一步下來,我們得到的樹,會最快的把不確定性降低到最小。每顆樹的分支,都可以確定一個類別,包含的信息量極少,確定性極大,這種類別,是可以進行預測的,不確定性小,穩定,可以用於預測。

        有了以上的知識儲備,那么我們要想生成一顆決策樹,只需要每次把特征的信息量最大的那個找出來進行划分即可。也就是不確定性最大的那個分支,我們要優先划分。這就會得出另外一個定義,條件信息熵H(Y|X)。

 

根據以上的推導,我們得出信息增益,H(Y)-H(Y|X)。可以看做是特征X的信息量,根據這個的最大值,依次得到每個特征,就是我們需要的決策樹。利用Python完成代如下,打包到一個類下面:

from math import log
import operator

# 計算香農熵
class Tree:
    def __init__(self):
        super()
    def calcShannonEnt(self, dataSet):
        num = len(dataSet)
        labelCounts = {}
        for fVec in dataSet:
            currentLabel = fVec[-1]
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / num
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt

    #按照特征划分數據集,特征的位置為index
    def splitDataSet(self, dataSet, index, value):
        retDataSet = []
        for featVec in dataSet:
            if featVec[index] == value:
                reducedFeatVec = featVec[:index]
                reducedFeatVec.extend(featVec[index+1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet

    #尋找信息增益最大的特征
    def chooseBestFeatureToSplit(self, dataSet):
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = self.calcShannonEnt(dataSet)
        bestInfoGain, bestFeature = 0.0, -1
        for i in range(numFeatures):
            featList = [example[i] for example in dataSet]
            uniqueVals = set(featList)
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy +=prob * self.calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            if (infoGain >= bestInfoGain):#這里注意,取等號,只有1個特征為時,可能無信息增加。
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

    # 如果分類不唯一,采用多數表決方法,決定葉子的分類
    def majorityCnt(self, classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            classCount[vote] += 1
        SortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
        return SortedClassCount[0][0]

    # 創建決策樹代碼
    def createTree(self, dataSet, labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0] == len(classList)):#類別完全相同,無需划分,一類
            return classList[0]
        if len(dataSet[0]) == 1: #處理了所有特征,依舊沒有完全划分,返回多數表決結果
            return self.majorityCnt(classList)
        bestFeat = self.chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del labels[bestFeat]
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:#利用遞歸構建決策樹
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = self.createTree(self.splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree

    def createDataSet(self):
        dataSet = [
            [1,1,"yes"],
            [1,0,"no"],
            [0,1,"no"],
            [0,1,"no"]
        ]
        labels =["no surfacing", "flippers"]
        return dataSet, labels
    def decisiontreeclassify(self, inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = self.decisiontreeclassify(secondDict[key],featLabels,testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel
if __name__ == "__main__":
    tree = Tree()
    myDat, myLabels =tree.createDataSet()
    inputTree = tree.createTree(myDat, myLabels)
    featLabels = ['no surfacing','flippers']
    print(inputTree)
    print(tree.decisiontreeclassify( inputTree, featLabels, [1,0]))
    print(tree.decisiontreeclassify( inputTree, featLabels, [1,1]))

下面的代碼,是畫出決策樹,便於查看,沒有封裝。

import matplotlib.pyplot as plt
# boxstyle是文本框類型 fc是邊框粗細 sawtooth是鋸齒形
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# annotate 注釋的意思
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs +=getNumLeafs(secondDict[key])
        else:
            numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = getTreeDepth(secondDict[key]) +1
        else:
            thisDepth =1
        if thisDepth >maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):#創建樹
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
def plotMidText(cntrPt, parentPt, txtString):#填充文本
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#i建數據集和標簽
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD     #按比例減少全局變量plotTree.yOff
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW

            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #繪制此節點帶箭頭的注解
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  #繪制此節點帶箭頭的注解
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD    #按比例增加全局變量plotTree.yOff

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
if __name__=="__main__":
    myTree = retrieveTree(0)
    list(myTree.keys())[0]
    createPlot(myTree)

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM