決策樹的python實現


決策樹

算法優缺點:

  • 優點:計算復雜度不高,輸出結果易於理解,對中間值缺失不敏感,可以處理不相關的特征數據

  • 缺點:可能會產生過度匹配的問題

  • 適用數據類型:數值型和標稱型

算法思想:

1.決策樹構造的整體思想:

決策樹說白了就好像是if-else結構一樣,它的結果就是你要生成這個一個可以從根開始不斷判斷選擇到葉子節點的樹,但是呢這里的if-else必然不會是讓我們認為去設置的,我們要做的是提供一種方法,計算機可以根據這種方法得到我們所需要的決策樹。這個方法的重點就在於如何從這么多的特征中選擇出有價值的,並且按照最好的順序由根到葉選擇。完成了這個我們也就可以遞歸構造一個決策樹了

2.信息增益

划分數據集的最大原則是將無序的數據變得更加有序。既然這又牽涉到信息的有序無序問題,自然要想到想弄的信息熵了。這里我們計算用的也是信息熵(另一種方法是基尼不純度)。公式如下:

數據需要滿足的要求:

1 數據必須是由列表元素組成的列表,而且所有的列白哦元素都要具有相同的數據長度
2 數據的最后一列或者每個實例的最后一個元素應是當前實例的類別標簽

函數:

calcShannonEnt(dataSet)
計算數據集的香農熵,分兩步,第一步計算頻率,第二部根據公式計算香農熵
splitDataSet(dataSet, aixs, value)
划分數據集,將滿足X[aixs]==value的值都划分到一起,返回一個划分好的集合(不包括用來划分的aixs屬性,因為不需要)
chooseBestFeature(dataSet)
選擇最好的屬性進行划分,思路很簡單就是對每個屬性都划分下,看哪個好。這里使用到了一個set來選取列表中唯一的元素,這是一中很快的方法
majorityCnt(classList)
因為我們遞歸構建決策樹是根據屬性的消耗進行計算的,所以可能會存在最后屬性用完了,但是分類還是沒有算完,這時候就會采用多數表決的方式計算節點分類
createTree(dataSet, labels)
基於遞歸構建決策樹。這里的label更多是對於分類特征的名字,為了更好看和后面的理解。

  1.  1 #coding=utf-8
     2 import operator
     3 from math import log
     4 import time
     5 
     6 def createDataSet():
     7     dataSet=[[1,1,'yes'],
     8             [1,1,'yes'],
     9             [1,0,'no'],
    10             [0,1,'no'],
    11             [0,1,'no']]
    12     labels = ['no surfaceing','flippers']
    13     return dataSet, labels
    14 
    15 #計算香農熵
    16 def calcShannonEnt(dataSet):
    17     numEntries = len(dataSet)
    18     labelCounts = {}
    19     for feaVec in dataSet:
    20         currentLabel = feaVec[-1]
    21         if currentLabel not in labelCounts:
    22             labelCounts[currentLabel] = 0
    23         labelCounts[currentLabel] += 1
    24     shannonEnt = 0.0
    25     for key in labelCounts:
    26         prob = float(labelCounts[key])/numEntries
    27         shannonEnt -= prob * log(prob, 2)
    28     return shannonEnt
    29 
    30 def splitDataSet(dataSet, axis, value):
    31     retDataSet = []
    32     for featVec in dataSet:
    33         if featVec[axis] == value:
    34             reducedFeatVec = featVec[:axis]
    35             reducedFeatVec.extend(featVec[axis+1:])
    36             retDataSet.append(reducedFeatVec)
    37     return retDataSet
    38     
    39 def chooseBestFeatureToSplit(dataSet):
    40     numFeatures = len(dataSet[0]) - 1#因為數據集的最后一項是標簽
    41     baseEntropy = calcShannonEnt(dataSet)
    42     bestInfoGain = 0.0
    43     bestFeature = -1
    44     for i in range(numFeatures):
    45         featList = [example[i] for example in dataSet]
    46         uniqueVals = set(featList)
    47         newEntropy = 0.0
    48         for value in uniqueVals:
    49             subDataSet = splitDataSet(dataSet, i, value)
    50             prob = len(subDataSet) / float(len(dataSet))
    51             newEntropy += prob * calcShannonEnt(subDataSet)
    52         infoGain = baseEntropy -newEntropy
    53         if infoGain > bestInfoGain:
    54             bestInfoGain = infoGain
    55             bestFeature = i
    56     return bestFeature
    57             
    58 #因為我們遞歸構建決策樹是根據屬性的消耗進行計算的,所以可能會存在最后屬性用完了,但是分類
    59 #還是沒有算完,這時候就會采用多數表決的方式計算節點分類
    60 def majorityCnt(classList):
    61     classCount = {}
    62     for vote in classList:
    63         if vote not in classCount.keys():
    64             classCount[vote] = 0
    65         classCount[vote] += 1
    66     return max(classCount)         
    67     
    68 def createTree(dataSet, labels):
    69     classList = [example[-1] for example in dataSet]
    70     if classList.count(classList[0]) ==len(classList):#類別相同則停止划分
    71         return classList[0]
    72     if len(dataSet[0]) == 1:#所有特征已經用完
    73         return majorityCnt(classList)
    74     bestFeat = chooseBestFeatureToSplit(dataSet)
    75     bestFeatLabel = labels[bestFeat]
    76     myTree = {bestFeatLabel:{}}
    77     del(labels[bestFeat])
    78     featValues = [example[bestFeat] for example in dataSet]
    79     uniqueVals = set(featValues)
    80     for value in uniqueVals:
    81         subLabels = labels[:]#為了不改變原始列表的內容復制了一下
    82         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, 
    83                                         bestFeat, value),subLabels)
    84     return myTree
    85     
    86 def main():
    87     data,label = createDataSet()
    88     t1 = time.clock()
    89     myTree = createTree(data,label)
    90     t2 = time.clock()
    91     print myTree
    92     print 'execute for ',t2-t1
    93 if __name__=='__main__':
    94     main()

     

    機器學習筆記索引




免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM