如何實現並應用決策樹算法?


本文對決策樹算法進行簡單的總結和梳理,並對著名的決策樹算法ID3(Iterative Dichotomiser 迭代二分器)進行實現,實現采用Python語言,一句老梗,“人生苦短,我用Python”,Python確實能夠省很多語言方面的事,從而可以讓我們專注於問題和解決問題的邏輯。

根據不同的數據,我實現了三個版本的ID3算法,復雜度逐步提升:

1.純標稱值無缺失數據集

2.連續值和標稱值混合且無缺失數據集

3.連續值和標稱值混合,有缺失數據集

第一個算法參考了《機器學習實戰》的大部分代碼,第二、三個算法基於前面的實現進行模塊的增加。

決策樹簡介

決策樹算法不用說大家應該都知道,是機器學習的一個著名算法,由澳大利亞著名計算機科學家Rose Quinlan發表。

決策樹是一種監督學習的分類算法,目的是學習出一顆決策樹,該樹中間節點是數據特征,葉子節點是類別,實際分類時根據樹的結構,一步一步根據當前數據特征取值選擇進入哪一顆子樹,直到走到葉子節點,葉子節點的類別就是此決策樹對此數據的學習結果。下圖就是一顆簡單的決策樹:

此決策樹用來判斷一個具有紋理,觸感,密度的西瓜是否是“好瓜”。

當有這樣一個西瓜,紋理清晰,密度為0.333,觸感硬滑,那么要你判斷是否是一個“好瓜”,這時如果通過決策樹來判斷,顯然可以一直順着紋理->清晰->密度<=0.382->否,即此瓜不是“好瓜”,一次決策就這樣完成了。正因為決策樹決策很方便,並且准確率也較高,所以常常被用來做分類器,也是“機器學習十大算法”之一C4.5的基本思想。

學習出一顆決策樹首要考慮一個問題,即 根據數據集構建當前樹應該選擇哪種屬性作為樹根,即划分標准? 

考慮最好的情況,一開始選擇某個特征,就把數據集划分成功,即在該特征上取某個值的全是一類。

考慮最壞的情況,不斷選擇特征,划分后的數據集總是雜亂無章,就二分類任務來說,總是有正類有負類,一直到特征全部用完了,划分的數據集合還是有正有負,這時只能用投票法,正類多就選正類作為葉子,否則選負類。

所以得出了一般結論: 隨着划分的進行,我們希望選擇一個特征,使得子節點包含的樣本盡可能屬於同一類別,即“純度”越高越好。

基於“純度”的標准不同,有三種算法:

1.ID3算法(Iterative Dichotomiser 迭代二分器),也是本文要實現的算法,基於信息增益即信息熵來度量純度

2.C4.5算法(Classifier 4.5),ID3 的后繼算法,也是昆蘭提出

3.CART算法(Classification And Regression Tree),基於基尼指數度量純度。

ID3算法簡介

信息熵是信息論中的一個重要概念,也叫“香農熵”,香農先生的事跡相比很多人都聽過,一個人開創了一門理論,牛的不行。香農理論中一個很重要的特征就是”熵“,即”信息內容的不確定性“,香農在進行信息的定量計算的時候,明確地把信息量定義為隨機不定性程度的減少。這就表明了他對信息的理解:信息是用來減少隨機不定性的東西。或者表達為香農逆定義:信息是確定性的增加。這也印證了決策樹以熵作為划分選擇的度量標准的正確性,即我們想更快速地從數據中獲得更多信息,我們就應該快速降低不確定性,即減少”熵“。

信息熵定義為:

D表示數據集,類別總數為|Y|,pk表示D中第k類樣本所占的比例。根據其定義,Ent的值越小,信息純度越高。Ent的范圍是[0,log|Y|]

下面要選擇某個屬性進行划分,要依次考慮每個屬性,假設當前考慮屬性a,a的取值有|V|種,那么我們希望取a作為划分屬性,划分到|V|個子節點后,所有子節點的信息熵之和即划分后的信息熵能夠有很大的減小,減小的最多的那個屬性就是我們選擇的屬性。

划分后的信息熵定義為:

 

所以用屬性a對樣本集D進行划分的信息增益就是原來的信息熵減去划分后的信息熵:

ID3算法就是這樣每次選擇一個屬性對樣本集進行划分,知道兩種情況使這個過程停止:

(1)某個子節點樣本全部屬於一類

(2)屬性都用完了,這時候如果子節點樣本還是不一致,那么只好少數服從多數了

(圖片來自網絡)

ID3算法實現(純標稱值)

如果樣本全部是標稱值即離散值的話,會比較簡單。

代碼:

from math import log
from operator import itemgetter
def createDataSet():            #創建數據集
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    featname = ['no surfacing', 'flippers']
    return dataSet,featname
def filetoDataSet(filename):
    fr = open(filename,'r')
    all_lines = fr.readlines()
    featname = all_lines[0].strip().split(',')[1:-1]
    print(featname)
    dataSet = []
    for line in all_lines[1:]:
        line = line.strip()
        lis = line.split(',')[1:]
        dataSet.append(lis)
    fr.close()
    return dataSet,featname
def calcEnt(dataSet):           #計算香農熵
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        label = featVec[-1]
        if label not in labelCounts.keys():
            labelCounts[label] = 0
        labelCounts[label] += 1
    Ent = 0.0
    for key in labelCounts.keys():
        p_i = float(labelCounts[key]/numEntries)
        Ent -= p_i * log(p_i,2)
    return Ent
def splitDataSet(dataSet, axis, value):   #划分數據集,找出第axis個屬性為value的數據
    returnSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
    return returnSet
def chooseBestFeat(dataSet):
    numFeat = len(dataSet[0])-1
    Entropy = calcEnt(dataSet)
    DataSetlen = float(len(dataSet))
    bestGain = 0.0
    bestFeat = -1
    for i in range(numFeat):
        allvalue = [featVec[i] for featVec in dataSet]
        specvalue = set(allvalue)
        nowEntropy = 0.0
        for v in specvalue:
            Dv = splitDataSet(dataSet,i,v)
            p = len(Dv)/DataSetlen
            nowEntropy += p * calcEnt(Dv)
        if Entropy - nowEntropy > bestGain:
            bestGain = Entropy - nowEntropy
            bestFeat = i
    return bestFeat
def Vote(classList):
    classdic = {}
    for vote in classList:
        if vote not in classdic.keys():
            classdic[vote] = 0
        classdic[vote] += 1
    sortedclassDic = sorted(classdic.items(),key=itemgetter(1),reverse=True)
    return sortedclassDic[0][0]
def createDecisionTree(dataSet,featnames):
    featname = featnames[:]              ################
    classlist = [featvec[-1] for featvec in dataSet]  #此節點的分類情況
    if classlist.count(classlist[0]) == len(classlist):  #全部屬於一類
        return classlist[0]
    if len(dataSet[0]) == 1:         #分完了,沒有屬性了
        return Vote(classlist)       #少數服從多數
    # 選擇一個最優特征進行划分
    bestFeat = chooseBestFeat(dataSet)
    bestFeatname = featname[bestFeat]
    del(featname[bestFeat])     #防止下標不准
    DecisionTree = {bestFeatname:{}}
    # 創建分支,先找出所有屬性值,即分支數
    allvalue = [vec[bestFeat] for vec in dataSet]
    specvalue = sorted(list(set(allvalue)))  #使有一定順序
    for v in specvalue:
        copyfeatname = featname[:]
        DecisionTree[bestFeatname][v] = createDecisionTree(splitDataSet(dataSet,bestFeat,v),copyfeatname)
    return DecisionTree
if __name__ == '__main__':
    filename = "D:\\MLinAction\\Data\\西瓜2.0.txt"
    DataSet,featname = filetoDataSet(filename)
    #print(DataSet)
    #print(featname)
    Tree = createDecisionTree(DataSet,featname)
    print(Tree)
View Code

解釋一下幾個函數:

filetoDataSet(filename)  將文件中的數據整理成數據集

calcEnt(dataSet)     計算香農熵

splitDataSet(dataSet, axis, value)     划分數據集,選擇出第axis個屬性的取值為value的所有數據集,即D^v,並去掉第axis個屬性,因為不需要了

chooseBestFeat(dataSet)      根據信息增益,選擇一個最好的屬性

Vote(classList)        如果屬性用完,類別仍不一致,投票決定

createDecisionTree(dataSet,featnames)     遞歸創建決策樹

--------------------------------------------------------------------------------

用西瓜數據集2.0對算法進行測試,西瓜數據集見 西瓜數據集2.0,輸出如下:

['色澤', '根蒂', '敲聲', '紋理', '臍部', '觸感']
{'紋理': {'清晰': {'根蒂': {'蜷縮': '是', '硬挺': '否', '稍蜷': {'色澤': {'青綠': '是', '烏黑': {'觸感': {'硬滑': '是', '軟粘': '否'}}}}}}, '稍糊': {'觸感': {'硬滑': '否', '軟粘': '是'}}, '模糊': '否'}}

為了能夠體現決策樹的優越性即決策方便,這里基於matplotlib模塊編寫可視化函數treePlot,對生成的決策樹進行可視化,可視化結果如下:

 

由於數據太少,沒有設置測試數據以驗證其准確度,但是我后面會根據乳腺癌的例子進行准確度的測試的,下面進入下一部分:

有連續值的情況

有連續值的情況如 西瓜數據集3.0 

一個屬性有很多種取值,我們肯定不能每個取值都做一個分支,這時候需要對連續屬性進行離散化,有幾種方法供選擇,其中兩種是:

1.對每一類別的數據集的連續值取平均值,再取各類的平均值的平均值作為划分點,將連續屬性化為兩類變成離散屬性

2.C4.5采用的二分法,排序離散屬性,取每兩個的中點作為划分點的候選點,計算以每個划分點划分數據集的信息增益,取最大的那個划分點將連續屬性化為兩類變成離散屬性,用該屬性進行划分的信息增益就是剛剛計算的最大信息增益。公式如下:

這里采用第二種,並在學習前對連續屬性進行離散化。增加處理的代碼如下:

def splitDataSet_for_dec(dataSet, axis, value, small):
    returnSet = []
    for featVec in dataSet:
        if (small and featVec[axis] <= value) or ((not small) and featVec[axis] > value):
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
    return returnSet
def DataSetPredo(filename,decreteindex):
    dataSet,featname = filetoDataSet(filename)
    Entropy = calcEnt(dataSet)
    DataSetlen = len(dataSet)
    for index in decreteindex:     #對每一個是連續值的屬性下標
        for i in range(DataSetlen):
            dataSet[i][index] = float(dataSet[i][index])
        allvalue = [vec[index] for vec in dataSet]
        sortedallvalue = sorted(allvalue)
        T = []
        for i in range(len(allvalue)-1):        #划分點集合
            T.append(float(sortedallvalue[i]+sortedallvalue[i+1])/2.0)
        bestGain = 0.0
        bestpt = -1.0
        for pt in T:          #對每個划分點
            nowent = 0.0
            for small in range(2):   #化為正類負類
                Dt = splitDataSet_for_dec(dataSet, index, pt, small)
                p = len(Dt) / float(DataSetlen)
                nowent += p * calcEnt(Dt)
            if Entropy - nowent > bestGain:
                bestGain = Entropy-nowent
                bestpt = pt
        featname[index] = str(featname[index]+"<="+"%.3f"%bestpt)
        for i in range(DataSetlen):
            dataSet[i][index] = "是" if dataSet[i][index] <= bestpt else "否"
    return dataSet,featname

主要是預處理函數DataSetPredo,對數據集提前離散化,然后再進行學習,學習代碼類似。輸出的決策樹如下:

有缺失值的情況

數據有缺失值是常見的情況,我們不好直接拋棄這些數據,因為這樣會損失大量數據,不划算,但是缺失值我們也無法判斷它的取值。怎么辦呢,辦法還是有的。

考慮兩個問題: 

1.有缺失值時如何進行划分選擇

2.已選擇划分屬性,有缺失值的樣本划不划分,如何划分?

問題1:有缺失值時如何進行划分選擇

基本思想是進行最優屬性選擇時,先只考慮無缺失值樣本,然后再乘以相應比例,得到在整個樣本集上的大致情況。連帶考慮到第二個問題的話,考慮給每一個樣本一個權重,此時每個樣本不再總是被看成一個獨立樣本,這樣有利於第二個問題的解決:即若樣本在屬性a上的值缺失,那么將其看成是所有值都取,只不過取每個值的權重不一樣,每個值的權重參考該值在無缺失值樣本中的比例,簡單地說,比如在無缺失值樣本集中,屬性a取去兩個值1和2,並且取1的權重和占整個權重和1/3,而取2的權重和占2/3,那么依據該屬性對樣本集進行划分時,遇到該屬性上有缺失值的樣本,那么我們認為該樣本取值2的可能性更大,於是將該樣本的權重乘以2/3歸到取值為2的樣本集中繼續進行划分構造決策樹,而乘1/3划到取值為1的樣本集中繼續構造。不知道我說清楚沒有。

公式如下:

其中,D~表示數據集D在屬性a上無缺失值的樣本,根據它來判斷a屬性的優劣,rho(即‘lou')表示屬性a的無缺失值樣本占所有樣本的比例,p~_k表示無缺失值樣本中第k類所占的比例,r~_v表示無缺失值樣本在屬性a上取值為v的樣本所占的比例。

在划分樣本時,如果有缺失值,則將樣本划分到所有子節點,在屬性a取值v的子節點上的權重為r~_v * 原來的權重。

更詳細的解讀參考《機器學習》P86-87。

根據權重法修改后的ID3算法實現如下:

from math import log
from operator import itemgetter

def filetoDataSet(filename):
    fr = open(filename,'r')
    all_lines = fr.readlines()
    featname = all_lines[0].strip().split(',')[1:-1]
    dataSet = []
    for line in all_lines[1:]:
        line = line.strip()
        lis = line.split(',')[1:]
        if lis[-1] == '2':
            lis[-1] = ''
        else:
            lis[-1] = ''
        dataSet.append(lis)
    fr.close()
    return dataSet,featname

def calcEnt(dataSet, weight):           #計算權重香農熵
    labelCounts = {}
    i = 0
    for featVec in dataSet:
        label = featVec[-1]
        if label not in labelCounts.keys():
            labelCounts[label] = 0
        labelCounts[label] += weight[i]
        i += 1
    Ent = 0.0
    for key in labelCounts.keys():
        p_i = float(labelCounts[key]/sum(weight))
        Ent -= p_i * log(p_i,2)
    return Ent

def splitDataSet(dataSet, weight, axis, value, countmissvalue):   #划分數據集,找出第axis個屬性為value的數據
    returnSet = []
    returnweight = []
    i = 0
    for featVec in dataSet:
        if featVec[axis] == '?' and (not countmissvalue):
            continue
        if countmissvalue and featVec[axis] == '?':
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
        if featVec[axis] == value:
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
            returnweight.append(weight[i])
        i += 1
    return returnSet,returnweight

def splitDataSet_for_dec(dataSet, axis, value, small, countmissvalue):
    returnSet = []
    for featVec in dataSet:
        if featVec[axis] == '?' and (not countmissvalue):
            continue
        if countmissvalue and featVec[axis] == '?':
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
        if (small and featVec[axis] <= value) or ((not small) and featVec[axis] > value):
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
    return returnSet
            
def DataSetPredo(filename,decreteindex):     #首先運行,權重不變為1
    dataSet,featname = filetoDataSet(filename)
    DataSetlen = len(dataSet)
    Entropy = calcEnt(dataSet,[1 for i in range(DataSetlen)])
    for index in decreteindex:     #對每一個是連續值的屬性下標
        UnmissDatalen = 0
        for i in range(DataSetlen):      #字符串轉浮點數
            if dataSet[i][index] != '?':
                UnmissDatalen += 1
                dataSet[i][index] = int(dataSet[i][index])
        allvalue = [vec[index] for vec in dataSet if vec[index] != '?']
        sortedallvalue = sorted(allvalue)
        T = []
        for i in range(len(allvalue)-1):        #划分點集合
            T.append(int(sortedallvalue[i]+sortedallvalue[i+1])/2.0)
        bestGain = 0.0
        bestpt = -1.0
        for pt in T:          #對每個划分點
            nowent = 0.0
            for small in range(2):   #化為正類(1)負類(0)
                Dt = splitDataSet_for_dec(dataSet, index, pt, small, False)
                p = len(Dt) / float(UnmissDatalen)
                nowent += p * calcEnt(Dt,[1.0 for i in range(len(Dt))])
            if Entropy - nowent > bestGain:
                bestGain = Entropy-nowent
                bestpt = pt
        featname[index] = str(featname[index]+"<="+"%d"%bestpt)
        for i in range(DataSetlen):
            if dataSet[i][index] != '?':
                dataSet[i][index] = "" if dataSet[i][index] <= bestpt else ""
    return dataSet,featname

def getUnmissDataSet(dataSet, weight, axis):
    returnSet = []
    returnweight = []
    tag = []
    i = 0
    for featVec in dataSet:
        if featVec[axis] == '?':
            tag.append(i)
        else:
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
        i += 1
    for i in range(len(weight)):
        if i not in tag:
            returnweight.append(weight[i])
    return returnSet,returnweight

def printlis(lis):
    for li in lis:
        print(li)
        
def chooseBestFeat(dataSet,weight,featname):
    numFeat = len(dataSet[0])-1
    DataSetWeight = sum(weight)
    bestGain = 0.0
    bestFeat = -1
    for i in range(numFeat):
        UnmissDataSet,Unmissweight = getUnmissDataSet(dataSet, weight, i)   #無缺失值數據集及其權重
        Entropy = calcEnt(UnmissDataSet,Unmissweight)      #Ent(D~)
        allvalue = [featVec[i] for featVec in dataSet if featVec[i] != '?']
        UnmissSumWeight = sum(Unmissweight)
        lou = UnmissSumWeight / DataSetWeight        #lou
        specvalue = set(allvalue)
        nowEntropy = 0.0
        for v in specvalue:      #該屬性的幾種取值
            Dv,weightVec_v = splitDataSet(dataSet,Unmissweight,i,v,False)   #返回 此屬性為v的所有樣本 以及 每個樣本的權重
            p = sum(weightVec_v) / UnmissSumWeight          #r~_v = D~_v / D~
            nowEntropy += p * calcEnt(Dv,weightVec_v)
        if lou*(Entropy - nowEntropy) > bestGain:
            bestGain = Entropy - nowEntropy
            bestFeat = i
    return bestFeat

def Vote(classList,weight):
    classdic = {}
    i = 0
    for vote in classList:
        if vote not in classdic.keys():
            classdic[vote] = 0
        classdic[vote] += weight[i]
        i += 1
    sortedclassDic = sorted(classdic.items(),key=itemgetter(1),reverse=True)
    return sortedclassDic[0][0]

def splitDataSet_adjustWeight(dataSet,weight,axis,value,r_v):
    returnSet = []
    returnweight = []
    i = 0
    for featVec in dataSet:
        if featVec[axis] == '?':
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
            returnweight.append(weight[i] * r_v)
        elif featVec[axis] == value:
            retVec = featVec[:axis]
            retVec.extend(featVec[axis+1:])
            returnSet.append(retVec)
            returnweight.append(weight[i])
        i += 1
    return returnSet,returnweight
    
def createDecisionTree(dataSet,weight,featnames):
    featname = featnames[:]              ################
    classlist = [featvec[-1] for featvec in dataSet]  #此節點的分類情況
    if classlist.count(classlist[0]) == len(classlist):  #全部屬於一類
        return classlist[0]
    if len(dataSet[0]) == 1:         #分完了,沒有屬性了
        return Vote(classlist,weight)       #少數服從多數
    # 選擇一個最優特征進行划分
    bestFeat = chooseBestFeat(dataSet,weight,featname)
    bestFeatname = featname[bestFeat]
    del(featname[bestFeat])     #防止下標不准
    DecisionTree = {bestFeatname:{}}
    # 創建分支,先找出所有屬性值,即分支數
    allvalue = [vec[bestFeat] for vec in dataSet if vec[bestFeat] != '?']
    specvalue = sorted(list(set(allvalue)))  #使有一定順序
    UnmissDataSet,Unmissweight = getUnmissDataSet(dataSet, weight, bestFeat)   #無缺失值數據集及其權重
    UnmissSumWeight = sum(Unmissweight)      # D~
    for v in specvalue:
        copyfeatname = featname[:]
        Dv,weightVec_v = splitDataSet(dataSet,Unmissweight,bestFeat,v,False)   #返回 此屬性為v的所有樣本 以及 每個樣本的權重
        r_v = sum(weightVec_v) / UnmissSumWeight          #r~_v = D~_v / D~
        sondataSet,sonweight = splitDataSet_adjustWeight(dataSet,weight,bestFeat,v,r_v)
        DecisionTree[bestFeatname][v] = createDecisionTree(sondataSet,sonweight,copyfeatname)
    return DecisionTree

if __name__ == '__main__':
    filename = "D:\\MLinAction\\Data\\breastcancer.txt"
    DataSet,featname = DataSetPredo(filename,[0,1,2,3,4,5,6,7,8])
    Tree = createDecisionTree(DataSet,[1.0 for i in range(len(DataSet))],featname)
    print(Tree)
View Code

有缺失值的情況如 西瓜數據集2.0alpha

實驗結果:

在乳腺癌數據集上的測試與表現

有了算法,我們當然想做一定的測試看一看算法的表現。這里我選擇了威斯康辛女性乳腺癌的數據。

數據總共有9列,每一列分別代表,以逗號分割

1 Sample code number (病人ID)
2 Clump Thickness 腫塊厚度
3 Uniformity of Cell Size 細胞大小的均勻性
4 Uniformity of Cell Shape 細胞形狀的均勻性
5 Marginal Adhesion 邊緣粘
6 Single Epithelial Cell Size 單上皮細胞的大小
7 Bare Nuclei 裸核
8 Bland Chromatin 乏味染色體
9 Normal Nucleoli 正常核
10 Mitoses 有絲分裂
11 Class: 2 for benign, 4 formalignant(惡性或良性分類)

[from Toby]

總共700條左右的數據,選取最后80條作為測試集,前面作為訓練集,進行學習。

使用分類器的代碼如下:

import treesID3 as id3
import treePlot as tpl
import pickle

def classify(Tree, featnames, X):
    classLabel = "未知"
    root = list(Tree.keys())[0]
    firstGen = Tree[root]
    featindex = featnames.index(root)  #根節點的屬性下標
    for key in firstGen.keys():   #根屬性的取值,取哪個就走往哪顆子樹
        if X[featindex] == key:
            if type(firstGen[key]) == type({}):
                classLabel = classify(firstGen[key],featnames,X)
            else:
                classLabel = firstGen[key]
    return classLabel

def StoreTree(Tree,filename):
    fw = open(filename,'wb')
    pickle.dump(Tree,fw)
    fw.close()

def ReadTree(filename):
    fr = open(filename,'rb')
    return pickle.load(fr)

if __name__ == '__main__':
    filename = "D:\\MLinAction\\Data\\breastcancer.txt"
    dataSet,featnames = id3.DataSetPredo(filename,[0,1,2,3,4,5,6,7,8])
    Tree = id3.createDecisionTree(dataSet[:620],[1.0 for i in range(len(dataSet))],featnames)
    tpl.createPlot(Tree)
    storetree = "D:\\MLinAction\\Data\\decTree.dect"
    StoreTree(Tree,storetree)
    #Tree = ReadTree(storetree)
    i = 1
    cnt = 0
    for lis in dataSet[620:]:
        judge = classify(Tree,featnames,lis[:-1])
        shouldbe = lis[-1]
        if judge == shouldbe:
            cnt += 1
        print("Test %d was classified %s, it's class is %s %s" %(i,judge,shouldbe,"=====" if judge==shouldbe else ""))
        i += 1
    print("The Tree's Accuracy is %.3f" % (cnt / float(i)))

訓練出的決策樹如下:

最終的正確率可以看到:

正確率約為96%左右,算是不差的分類器了。

我的乳腺癌數據見:http://7xt9qk.com2.z0.glb.clouddn.com/breastcancer.txt

至此,決策樹算法ID3的實現完畢,下面考慮基於基尼指數和信息增益率進行划分選擇,以及考慮實現剪枝過程,因為我們可以看到上面訓練出的決策樹還存在着很多冗余分支,是因為實現過程中,由於數據量太大,每個分支都不完全純凈,所以會創建往下的分支,但是分支投票的結果又是一致的,而且數據量再大,特征數再多的話,決策樹會非常大非常復雜,所以剪枝一般是必做的一步。剪枝分為先剪枝和后剪枝,如果細說的話可以寫很多了。

此文亦可見:這里
參考資料:《機器學習》《機器學習實戰》通過本次實戰也發現了這兩本書中的一些錯誤之處。

lz初學機器學習不久,如有錯漏之處請多包涵指出或者各位有什么想法或意見歡迎評論去告訴我:)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM