決策樹---ID3算法(介紹及Python實現)


決策樹---ID3算法

 

決策樹:

以天氣數據庫的訓練數據為例。

 

Outlook

Temperature

Humidity

Windy

PlayGolf?

sunny

85

85

FALSE

no

sunny

80

90

TRUE

no

overcast

83

86

FALSE

yes

rainy

70

96

FALSE

yes

rainy

68

80

FALSE

yes

rainy

65

70

TRUE

no

overcast

64

65

TRUE

yes

sunny

72

95

FALSE

no

sunny

69

70

FALSE

yes

rainy

75

80

FALSE

yes

sunny

75

70

TRUE

yes

overcast

72

90

TRUE

yes

overcast

81

75

FALSE

yes

rainy

71

91

TRUE

no

這個例子是根據報告天氣條件的記錄來決定是否外出打高爾夫球。

 

作為分類器,決策樹是一棵有向無環樹。

由根節點、葉節點、內部點、分割屬性、分割判斷規則構成

 

生成階段:決策樹的構建和決策樹的修剪。

根據分割方法的不同:有基於信息論(Information Theory的方法和基於最小GINI指數(lowest GINI index的方法。對應前者的常見方法有ID3、C4.5,后者的有CART

 ID3 算法

       ID3的基本概念是:

1)  決策樹中的每一個非葉子節點對應着一個特征屬性,樹枝代表這個屬性的值。一個葉節點代表從樹根到葉節點之間的路徑所對應的記錄所屬的類別屬性值。這就是決策樹的定義。

2)  在決策樹中,每一個非葉子節點都將與屬性中具有最大信息量的特征屬性相關聯。

3)  熵通常是用於測量一個非葉子節點的信息量大小的名詞。

 

熱力學中表征物質狀態的參量之一,用符號S表示,其物理意義是體系混亂程度的度量。熱力學第二定律(second law of thermodynamics),熱力學基本定律之一,又稱“熵增定律”,表明在自然過程中,一個孤立系統的總混亂度(即“熵”)不會減小。

在信息論中,變量的不確定性越大,熵也就越大,把它搞清楚所需要的信息量也就越大。信息熵是信息論中用於度量信息量的一個概念。一個系統越是有序,信息熵就越低;反之,一個系統越是混亂,信息熵就越高。所以,信息熵也可以說是系統有序化程度的一個度量。

 信息增益的計算

定義1:若存在個相同概率的消息,則每個消息的概率是,一個消息傳遞的信息量為。若有16個事件,則,需要4個比特來代表一個消息。

定義2若給定概率分布則由該分布傳遞的信息量稱為的熵,即

 

例:若是,則是1;若是,則是0.92;若

是,則是0(注意概率分布越均勻,其信息量越大)

定義3若一個記錄的集合根據類別屬性的值被分為相互獨立的類,則識別的一個元素所屬哪個類別所需要的信息量是,其中是的概率分布,即

 

 

仍以天氣數據庫的數據為例。我們統計了14天的氣象數據(指標包括outlook,temperature,humidity,windy),並已知這些天氣是否打球(play)。如果給出新一天的氣象指標數據,判斷一下會不會去打球。在沒有給定任何天氣信息時,根據歷史數據,我們知道一天中打球的概率是9/14,不打的概率是5/14。此時的熵為:

 

定義4:若我們根據某一特征屬性將分成集合,則確定中的一個元素類的信息量可通過確定的加權平均值來得到,即的加權平均值為:

 

 

 

Outlook

temperature

humidity

windy

play

 

yes

no

 

 

 

yes

no

yes

no

sunny

2

3

False

6

2

9

5

overcast

4

0

True

3

3

 

 

rainy

3

2

 

 

 

 

 

 

針對屬性Outlook,我們來計算

定義5:將信息增益定義為:

 

即增益的定義是兩個信息量之間的差值,其中一個信息量是需確定的一個元素的信息量,另一個信息量是在已得到的屬性的值后確定的一個元素的信息量,即信息增益與屬性相關。

針對屬性Outlook的增益值:

 

若用屬性windy替換outlook,可以得到,。即outlook比windy取得的信息量大。

ID3算法的Python實現

import math 
import operator

def calcShannonEnt(dataset):
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] +=1
        
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*math.log(prob, 2)
    return shannonEnt
    
def CreateDataSet():
    dataset = [[1, 1, 'yes' ], 
               [1, 1, 'yes' ], 
               [1, 0, 'no'], 
               [0, 1, 'no'], 
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    numberFeatures = len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0;
    bestFeature = -1;
    for i in range(numberFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy =0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    classCount ={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]=1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 
    return sortedClassCount[0][0]
 

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

        
        
myDat,labels = CreateDataSet()
createTree(myDat,labels)

運行結果如下:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM