決策樹原理實例(python代碼實現)


決策樹原理實例(python代碼實現)

決策數(Decision Tree)在機器學習中也是比較常見的一種算法,屬於監督學習中的一種。看字面意思應該也比較容易理解,相比其他算法比如支持向量機(SVM)或神經網絡,似乎決策樹感覺“親切”許多。

  • 優點:計算復雜度不高,輸出結果易於理解,對中間值的缺失值不敏感,可以處理不相關特征數據。
  • 缺點:可能會產生過度匹配的問題。
  • 使用數據類型:數值型和標稱型。

簡單介紹完畢,讓我們來通過一個例子讓決策樹“原形畢露”。

一天,老師問了個問題,只根據頭發和聲音怎么判斷一位同學的性別。 
為了解決這個問題,同學們馬上簡單的統計了7位同學的相關特征,數據如下:

頭發 聲音 性別

機智的同學A想了想,先根據頭發判斷,若判斷不出,再根據聲音判斷,於是畫了一幅圖,如下: 
同學A 
於是,一個簡單、直觀的決策樹就這么出來了。頭發長、聲音粗就是男生;頭發長、聲音細就是女生;頭發短、聲音粗是男生;頭發短、聲音細是女生。 
原來機器學習中決策樹就這玩意,這也太簡單了吧。。。 
這時又蹦出個同學B,想先根據聲音判斷,然后再根據頭發來判斷,如是大手一揮也畫了個決策樹: 
同學B 
同學B的決策樹:首先判斷聲音,聲音細,就是女生;聲音粗、頭發長是男生;聲音粗、頭發長是女生。

那么問題來了:同學A和同學B誰的決策樹好些?計算機做決策樹的時候,面對多個特征,該如何選哪個特征為最佳的划分特征?

划分數據集的大原則是:將無序的數據變得更加有序。 
我們可以使用多種方法划分數據集,但是每種方法都有各自的優缺點。於是我們這么想,如果我們能測量數據的復雜度,對比按不同特征分類后的數據復雜度,若按某一特征分類后復雜度減少的更多,那么這個特征即為最佳分類特征。 
Claude Shannon 定義了熵(entropy)和信息增益(information gain)。 
用熵來表示信息的復雜度,熵越大,則信息越復雜。公式如下: 
熵 
信息增益(information gain),表示兩個信息熵的差值。 
首先計算未分類前的熵,總共有8位同學,男生3位,女生5位。 
熵(總)=-3/8*log2(3/8)-5/8*log2(5/8)=0.9544 
接着分別計算同學A和同學B分類后信息熵。 
同學A首先按頭發分類,分類后的結果為:長頭發中有1男3女。短頭發中有2男2女。 
熵(同學A長發)=-1/4*log2(1/4)-3/4*log2(3/4)=0.8113 
熵(同學A短發)=-2/4*log2(2/4)-2/4*log2(2/4)=1 
熵(同學A)=4/8*0.8113+4/8*1=0.9057 
信息增益(同學A)=熵(總)-熵(同學A)=0.9544-0.9057=0.0487 
同理,按同學B的方法,首先按聲音特征來分,分類后的結果為:聲音粗中有3男3女。聲音細中有0男2女。 
熵(同學B聲音粗)=-3/6*log2(3/6)-3/6*log2(3/6)=1 
熵(同學B聲音粗)=-2/2*log2(2/2)=0 
熵(同學B)=6/8*1+2/8*0=0.75 
信息增益(同學B)=熵(總)-熵(同學A)=0.9544-0.75=0.2087

按同學B的方法,先按聲音特征分類,信息增益更大,區分樣本的能力更強,更具有代表性。 
以上就是決策樹ID3算法的核心思想。 
接下來用python代碼來實現ID3算法:

復制代碼
 1 from math import log
 2 import operator
 3 
 4 def calcShannonEnt(dataSet):  # 計算數據的熵(entropy)
 5     numEntries=len(dataSet)  # 數據條數
 6     labelCounts={}
 7     for featVec in dataSet:
 8         currentLabel=featVec[-1] # 每行數據的最后一個字(類別)
 9         if currentLabel not in labelCounts.keys():
10             labelCounts[currentLabel]=0
11         labelCounts[currentLabel]+=1  # 統計有多少個類以及每個類的數量
12     shannonEnt=0
13     for key in labelCounts:
14         prob=float(labelCounts[key])/numEntries # 計算單個類的熵值
15         shannonEnt-=prob*log(prob,2) # 累加每個類的熵值
16     return shannonEnt
17 
18 def createDataSet1():    # 創造示例數據
19     dataSet = [['長', '粗', '男'],
20                ['短', '粗', '男'],
21                ['短', '粗', '男'],
22                ['長', '細', '女'],
23                ['短', '細', '女'],
24                ['短', '粗', '女'],
25                ['長', '粗', '女'],
26                ['長', '粗', '女']]
27     labels = ['頭發','聲音']  #兩個特征
28     return dataSet,labels
29 
30 def splitDataSet(dataSet,axis,value): # 按某個特征分類后的數據
31     retDataSet=[]
32     for featVec in dataSet:
33         if featVec[axis]==value:
34             reducedFeatVec =featVec[:axis]
35             reducedFeatVec.extend(featVec[axis+1:])
36             retDataSet.append(reducedFeatVec)
37     return retDataSet
38 
39 def chooseBestFeatureToSplit(dataSet):  # 選擇最優的分類特征
40     numFeatures = len(dataSet[0])-1
41     baseEntropy = calcShannonEnt(dataSet)  # 原始的熵
42     bestInfoGain = 0
43     bestFeature = -1
44     for i in range(numFeatures):
45         featList = [example[i] for example in dataSet]
46         uniqueVals = set(featList)
47         newEntropy = 0
48         for value in uniqueVals:
49             subDataSet = splitDataSet(dataSet,i,value)
50             prob =len(subDataSet)/float(len(dataSet))
51             newEntropy +=prob*calcShannonEnt(subDataSet)  # 按特征分類后的熵
52         infoGain = baseEntropy - newEntropy  # 原始熵與按特征分類后的熵的差值
53         if (infoGain>bestInfoGain):   # 若按某特征划分后,熵值減少的最大,則次特征為最優分類特征
54             bestInfoGain=infoGain
55             bestFeature = i
56     return bestFeature
57 
58 def majorityCnt(classList):    #按分類后類別數量排序,比如:最后分類為2男1女,則判定為男;
59     classCount={}
60     for vote in classList:
61         if vote not in classCount.keys():
62             classCount[vote]=0
63         classCount[vote]+=1
64     sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
65     return sortedClassCount[0][0]
66 
67 def createTree(dataSet,labels):
68     classList=[example[-1] for example in dataSet]  # 類別:男或女
69     if classList.count(classList[0])==len(classList):
70         return classList[0]
71     if len(dataSet[0])==1:
72         return majorityCnt(classList)
73     bestFeat=chooseBestFeatureToSplit(dataSet) #選擇最優特征
74     bestFeatLabel=labels[bestFeat]
75     myTree={bestFeatLabel:{}} #分類結果以字典形式保存
76     del(labels[bestFeat])
77     featValues=[example[bestFeat] for example in dataSet]
78     uniqueVals=set(featValues)
79     for value in uniqueVals:
80         subLabels=labels[:]
81         myTree[bestFeatLabel][value]=createTree(splitDataSet\
82                             (dataSet,bestFeat,value),subLabels)
83     return myTree
84 
85 
86 if __name__=='__main__':
87     dataSet, labels=createDataSet1()  # 創造示列數據
88     print(createTree(dataSet, labels))  # 輸出決策樹模型結果
復制代碼

 

輸出結果為:

1 {'聲音': {'細': '女', '粗': {'頭發': {'短': '男', '長': '女'}}}}

這個結果的意思是:首先按聲音分類,聲音細為女生;然后再按頭發分類:聲音粗,頭發短為男生;聲音粗,頭發長為女生。 
這個結果也正是同學B的結果。 
補充說明:判定分類結束的依據是,若按某特征分類后出現了最終類(男或女),則判定分類結束。使用這種方法,在數據比較大,特征比較多的情況下,很容易造成過擬合,於是需進行決策樹枝剪,一般枝剪方法是當按某一特征分類后的熵小於設定值時,停止分類。

ID3算法存在的缺點: 
1. ID3算法在選擇根節點和內部節點中的分支屬性時,采用信息增益作為評價標准。信息增益的缺點是傾向於選擇取值較多是屬性,在有些情況下這類屬性可能不會提供太多有價值的信息。 
2. ID3算法只能對描述屬性為離散型屬性的數據集構造決策樹 。

為了改進決策樹,又提出了ID4.5算法和CART算法。之后有時間會介紹這兩種算法。

參考: 
Machine Learning in Action 
統計學習方法

轉載:http://blog.csdn.net/csqazwsxedc/article/details/65697652


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM