《機器學習實戰》筆記——決策樹(ID3)


閑來無事最近復習了一下ID3決策樹算法,並憑着理解用pandas實現了一遍。對pandas更熟悉的朋友可供參考(鏈接如下)。相比本篇博文,更簡明清晰,更適合復習用。

https://github.com/DianeSoHungry/ShallowMachineLearningCodeItOut/blob/master/ID3.ipynb

 

 

 

現在要介紹的是ID3決策樹算法,只適用於標稱型數據,不適用於數值型數據。

 

決策樹學習算法最大的優點是,他可以自學習,在學習過程中,不需要使用者了解過多的背景知識、領域知識,只需要對訓練實例進行較好的標注就可以自學習了。

 

建立決策樹的關鍵在於當前狀態下選擇哪一個屬性作為分類依據,根據不同的目標函數,有三種主要的算法

ID3(Iterative Dichotomiser)

C4.5

CART(Classification And Regression Tree)

 

問題描述:

下面是一個小型的數據集,5條記錄,2個特征(屬性),有標簽。

根據這個數據集,我們可以建立如下決策樹(用matplotlib的注釋功能畫的)。

觀察決策樹,決策節點為特征,其分支為決策節點的各個不同取值,葉節點為預測值。

建樹結束也就是建立好了一個決策樹分類器,有了分類器,就可以根據這個分類器對其他的魚進行預測了。預測准確性今天暫且不討論。

那么如何建立這樣的決策樹呢?

第一步:建立決策樹。

1.1 利用信息增益尋找當前最佳分類特征

想象現在你是一個判斷結點,你從頭頂的分支上獲得了一個數據集,表中包含標簽和若干屬性。你現在要根據某個屬性來對你接收到的數據集進行分組。到底用哪個屬性來作為划分依據呢?

我們用信息增益來選擇某個節點上用哪個特征來進行分類。

 

什么是信息?

如果待分類的事物可能划分在多個分類中,則每個分類xi的信息定義為:

(這里log前面應該有個負號。)

 

什么是香農熵?

香農熵是所有類別所有可能類別信息的期望值,即:

 

什么是信息增益?

信息增益=原香農熵-新香農熵

 

注意:新香農熵為按照某特征划分之后,每個分支數據集的香農熵之和。

  

可以這樣想:香農熵相當於數據類別(標簽)的混亂程度,信息增益可以衡量划分數據集前后數據(標簽)向有序性發展的程度。因此,回到怎樣利用信息增益尋找當前最佳分類特征的話題,假如你是一個判斷節點,你拿來一個數據集,數據集里面有若干個特征,你需要從中選取一個特征,使得信息增益最大(注意:將數據集中在該特征上取值相同的記錄划分到同一個分支,得到若干個分支數據集,每個分支數據集都有自己的香農熵,各個分支數據集的香農熵的期望才是新香農熵)。要找到這個特征只需要將數據集中的每個特征遍歷一次,求信息增益,取獲得最大信息增益的那個特征。

代碼如下(其中,calcShannonEnt(dataSet)函數用來計算數據集dataSet的香農熵,splitDataSet(dataSet, axis, value)函數將數據集dataSet的第axis列中特征值為value的記錄挑出來,組成分支數據集返回給函數。這兩個函數后面會給出函數定義。):

 1 # 3-3 選擇最好的'數據集划分方式'(特征)
 2 # 一個一個地試每個特征,如果某個按照某個特征分類得到的信息增益(原香農熵-新香農熵)最大,
 3 # 則選這個特征作為最佳數據集划分方式
 4 def chooseBestFeatureToSplit(dataSet):
 5     numFeatures = len(dataSet[0]) - 1
 6     baseEntropy = calcShannonEnt(dataSet)
 7     bestInfoGain = 0.0
 8     bestFeature = -1
 9     for i in range(numFeatures):
10         featList = [example[i] for example in dataSet]
11         uniqueVals = set(featList)
12         newEntropy = 0.0
13         for value in uniqueVals:
14             subDataSet = splitDataSet(dataSet, i, value)
15             prob = len(subDataSet) / float(len(dataSet))
16             newEntropy += prob * calcShannonEnt(subDataSet)
17         infoGain = baseEntropy - newEntropy
18         if (infoGain > bestInfoGain):
19             bestInfoGain = infoGain
20             bestFeature = i
21     return bestFeature

 

calcShannonEnt(dataSet)函數代碼:

 1 def calcShannonEnt(dataSet):
 2     numEntries = len(dataSet)    # 總記錄數
 3     labelCounts = {}    # dataSet中所有出現過的標簽值為鍵,相應標簽值出現過的次數作為值
 4     for featVec in dataSet:
 5         currentLabel = featVec[-1]
 6         labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
 7     shannonEnt = 0.0
 8     for key in labelCounts:
 9         prob = -float(labelCounts[key])/numEntries
10         shannonEnt += prob * log(prob, 2)
11     return shannonEnt

 

splitDataSet(dataSet, axis, value)函數代碼:

 1 # 3-2 按照給定特征划分數據集(在某個特征axis上,值等於value的所有記錄
 2 # 組成新的數據集retDataSet,新數據集不需要axis這個特征,注意value是特征值,axis指的是特征(所在的列下標))
 3 def splitDataSet(dataSet, axis, value):
 4     retDataSet = []
 5     for featVec in dataSet:
 6         if featVec[axis] == value:
 7             reducedFeatVec = featVec[:axis]
 8             reducedFeatVec.extend(featVec[axis+1:])
 9             retDataSet.append(reducedFeatVec)
10     return retDataSet

 

1.2 建樹

建樹是一個遞歸的過程。

 

遞歸結束的標志(判斷某節點是葉節點的標志):

情況1. 分到該節點的數據集中,所有記錄的標簽列取值都一樣。

情況2. 分到該節點的數據集中,只剩下標簽列。

 

a. 經判斷,若是葉節點,則:

對應情況1,返回數據集中第一條記錄的標簽值(反正所有標簽值都一樣)。

對應情況2,返回數據集中所有標簽值中,出現次數最多的那個標簽值(代碼中,定義一個函數majorityCnt(classList)來實現)

 

b. 經判斷,若不是葉節點,則:

step1. 建立一個字典,字典的鍵為該數據集上選出的最佳特征(划分依據)。

step2. 將具有相同特征值的記錄組成新的數據集(利用splitDataSet(dataSet, axis, value)函數實現,注意期間拋棄了當前用於划分數據的特征列),對新的數據集們進行遞歸建樹。

 

建樹代碼:

 1 # 3-4 創建樹的函數代碼
 2 # 如果非葉子結點,則以當前數據集建樹,並返回該樹。該樹的根節點是一個字典,鍵為划分當前數據集的最佳特征,值為按照鍵值划分后各個數據集構造的樹
 3 # 葉子節點有兩種:1.只剩沒有特征時,葉子節點的返回值為所有記錄中,出現次數最多的那個標簽值 2.該葉子節點中,所有記錄的標簽相同。
 4 
 5 def createTree(dataSet, labels): #label向量的維度為特征數,不是記錄數,是不同列下標對應的特征
 6     classList = [example[-1] for example in dataSet]
 7     if classList.count(classList[0]) == len(classList):
 8         return classList[0]
 9     if len(dataSet[0]) == 1:
10         return majorityCnt(classList)
11     bestFeat = chooseBestFeatureToSplit(dataSet)
12     bestFeatLabel = labels[bestFeat]
13     myTree = {bestFeatLabel: {}}
14     del(labels[bestFeat])
15     featValues = [example[bestFeat] for example in dataSet]
16     uniqueVals = set(featValues)
17     for value in uniqueVals:  #遞歸建子樹,若值為字典,則非葉節點,若為字符串,則為葉節點
18         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels)
19     return myTree

 

用上面給出的數據來建立一顆決策樹做示范:

在同一個程序中輸入如下代碼並運行:

 1 def createDataSet():
 2     dataSet = [[1, 1, 'yes'],
 3                [1, 1, 'yes'],
 4                [1, 0, 'no'],
 5                [0, 1, 'no'],
 6                [0, 1, 'no']]
 7     labels = ['no surfacing', 'flippers']
 8     return dataSet, labels
 9 
10 myDat, labels = createDataSet()
11 myTree = createTree(myDat, labels)
12 print myTree

運行結果為:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

 若利用后面畫決策樹的代碼可以畫出這顆決策樹:

 

案例:

我們通過建立決策樹來預測患者需要佩戴哪種隱形眼鏡(soft(軟材質)、hard(硬材質)、no lenses(不適合硬性眼睛)),數據集包含下面幾個特征:age(年齡), prescript(近視還是遠視), astigmatic(散光), tearRate(眼淚清除率)

建樹的結果為:

{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}

 

畫出來是這個樣子:

 

 

畫決策樹的代碼(不講)

涉及matplotlib.pyplot模塊中的annotation的用法,點擊鏈接進入官網學習這塊內容的prerequisite。

 1 # _*_coding:utf-8_*_
 2 
 3 # 3-7 plotTree函數
 4 import matplotlib.pyplot as plt
 5 
 6 # 定義節點和箭頭格式的常量
 7 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
 8 leafNode = dict(boxstyle="round4", fc="0.8")
 9 arrow_args = dict(arrowstyle="<-")
10 
11 
12 def plotMidTest(cntrPt, parentPt,txtString):
13     xMid = (parentPt[0] + cntrPt[0])/2.0
14     yMid = (parentPt[1] + cntrPt[1])/2.0
15     createPlot.ax1.text(xMid, yMid, txtString)
16 
17 # 繪制自身
18 # 若當前子節點不是葉子節點,遞歸
19 # 若當子節點為葉子節點,繪制該節點
20 def plotTree(myTree, parentPt, nodeTxt):
21     numLeafs = getNumLeafs(myTree)
22     # depth = getTreeDepth(myTree)
23     firstStr = myTree.keys()[0]
24     cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yoff)
25     plotMidTest(cntrPt, parentPt, nodeTxt)
26     plotNode(firstStr, cntrPt, parentPt, decisionNode)
27     secondDict = myTree[firstStr]
28     plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
29     for key in secondDict.keys():
30         if type(secondDict[key]).__name__=='dict':
31             plotTree(secondDict[key], cntrPt, str(key))
32         else:
33             plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
34             plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
35             plotMidTest((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
36     plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD
37 
38 
39 # figure points
40 # 畫結點的模板
41 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
42     createPlot.ax1.annotate(nodeTxt,  # 注釋的文字,(一個字符串)
43                             xy=parentPt,  # 被注釋的地方(一個坐標)
44                             xycoords='axes fraction',  # xy所用的坐標系
45                             xytext=centerPt,  # 插入文本的地方(一個坐標)
46                             textcoords='axes fraction', # xytext所用的坐標系
47                             va="center",
48                             ha="center",
49                             bbox=nodeType,  # 注釋文字用的框的格式
50                             arrowprops=arrow_args)  # 箭頭屬性
51 
52 
53 def createPlot(inTree):
54     fig = plt.figure(1, facecolor='white')
55     fig.clf()
56     axprops = dict(xticks=[], yticks=[])
57     createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
58     plotTree.totalW = float(getNumLeafs(inTree))
59     plotTree.totalD = float(getTreeDepth(inTree))
60     plotTree.xoff = -0.5/plotTree.totalW
61     plotTree.yoff = 1.0
62 
63     plotTree(inTree, (0.5, 1.0),'') #樹的引用作為父節點,但不畫出來,所以用''
64     plt.show()
65 
66 def getNumLeafs(myTree):
67     numLeafs = 0
68     firstStr = myTree.keys()[0]
69     secondDict = myTree[firstStr]
70     for key in secondDict.keys():
71         if type(secondDict[key]).__name__ =='dict':
72             numLeafs += getNumLeafs(secondDict[key])
73         else:
74             numLeafs += 1
75     return numLeafs
76 
77 # 子樹中樹高最大的那一顆的高度+1作為當前數的高度
78 def getTreeDepth(myTree):
79     maxDepth = 0    #用來記錄最高子樹的高度+1
80     firstStr = myTree.keys()[0]
81     secondDict = myTree[firstStr]
82     for key in secondDict.keys():
83         if type(secondDict[key]).__name__ == 'dict':
84             thisDepth = 1 + getTreeDepth(secondDict[key])
85         else:
86             thisDepth = 1
87         if(thisDepth > maxDepth):
88             maxDepth = thisDepth
89     return maxDepth
90 
91 # 方便測試用的人造測試樹
92 def retrieveTree(i):
93     listofTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
94                    {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
95                    ]
96     return listofTrees[i]

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM