python實現決策樹ID3算法


一、決策樹概論
決策樹是根據訓練數據集,按屬性跟類型,構建一棵樹形結構。可以按照這棵樹的結構,對測試數據進行分類。同時決策樹也可以用來處理預測問題(回歸)。

二、決策樹ID3的原理
有多種類型的決策樹,本文介紹的是ID3算法。
首先按照“信息增益”找出最有判別力的屬性,把這個屬性作為根節點,屬性的所有取值作為該根節點的分支,把樣例分成多個子集,每個子集又是一個子樹。以此遞歸,一直進行到所有子集僅包含同一類型的數據為止。最后得到一棵決策樹。ID3主要是按照按照每個屬性的信息增益值最大的屬性作為根節點進行划分。
ID3的算法思路
1、對當前訓練集,計算各屬性的信息增益(假設有屬性A1,A2,…An);
2、選擇信息增益最大的屬性Ak(1<=k<=n),作為根節點;
3、把在Ak處取值相同的例子歸於同一子集,作為該節點的一個樹枝,Ak取幾個值就得幾個子集;
4、若在某個子集中的所有樣本都是屬於同一個類型(本位只討論正(Y)、反(N)兩種類型的情況),則給該分支標上類型號作為葉子節點;
5、對於同時含有多種(兩種)類型的子集,則遞歸調用該算法思路來完成樹的構造。
使用決策樹對一下數據進行分類
如圖:1表示數據集的屬性,有4個屬性(outlook)
如上圖:
1表示數據集的屬性,有4個屬性(outlook、temperature、humidity、windy);
2是二維矩陣,每行表示一個訓練樣本數據,每列表示各個測試樣本的某個屬性值(編號3除外),例如outlook這個屬性有3個取值(sunny,rain,overcast)
3是各個訓練樣本的類型(這里只有兩種類型Y,N)
4是測試樣本,要求我們求出各個測試樣本的類型(分類)
求解步驟
1、計算信息熵
這里寫圖片描述
按照該公式,計算上面數據的信息熵。有上圖2中測試樣本的數據類型只有兩種(Y,N)所以,X=[Y,N],測試數據一共有7行,期中Y類型有4個,N類型有3個。
H(X) = -p(Y)log2p(Y)-p(X)log2p(X)=-4/7*log2(4/7)-3/7*log2(3/7)
2、計算各個屬性的信息增益
例如:對於測試集的第一列(屬性outlook),有3種取值
屬性outlook的信息增益值為:
g(X|A=”outlook”)=H(X)-H(X|A=”outlook”)
這里寫圖片描述
期中1公式表示的是outlook值等於sunny 的情況,2表示的是值等於overcast情況,3表示值等於rain情況。
1項中的2/7表示該值的樣本有2個,總樣本有7個;
2/2表示這兩個樣本中有2個是屬於N類型,0/2是表示有0個是屬於Y類型。
3、按照以上的公式,求出根節點
g(X|A=”outlook”)
g(X|A=”temperature”)
g(X|A=” humidity”)
g(X|A=”windy”)
4、在對不是同一類型的數據進行遞歸建樹
這里寫圖片描述
如上圖,第一次求出第一個節點“outlook”,該節點有三個分支,期中第一個分支sunny的數據都是屬於N類型,所歸為一類;同樣第二個分支overcast屬於同一類型(Y),也歸為一類;都標上類型符作為葉子節點。而第三個分支windy中既有N類型,也有Y類型,所以需要繼續對outlook=”windy”的進行遞歸調用以上算法。最終得到上圖的決策樹。
5、對測試集按照前面建好的決策樹進行分類
例如第一行測試數據的outlook屬性的值是“sunny”,所以預測是屬於N類型;同理第2、3…行測試樣本的結果為N,Y, N, Y, Y, N。
python編程實現
(代碼來自《機器學習實戰》)
1、從txt文件中讀取訓練集數據,並生成二維列表

#讀取數據文檔中的訓練數據(生成二維列表)
def createTrainData():
    lines_set = open('../data/ID3/Dataset.txt').readlines()
    labelLine = lines_set[2];
    labels = labelLine.strip().split()
    lines_set = lines_set[4:11]
    dataSet = [];
    for line in lines_set:
        data = line.split();
        dataSet.append(data);
    return dataSet, labels

代碼分析:
第一行:讀取txt文件,每行作為一個元素組成列表復制給lines_set;
第二行:lines_set[2]里存放的是各屬性名(outlook、temperature、humidity、windy);
labelLine.strip().split():為對讀取到的一行字符串,按空格對字符串進行切割(並去掉字符串),labels 是存放所有屬性名的列表;
lines_set[4:11]:為訓練樣本;
dataSet:使用二維矩陣(列表)存放訓練樣本的數據(包括各屬性的值已經類型)
2、讀取測試集數據

#讀取數據文檔中的測試數據(生成二維列表)
def createTestData():
    lines_set = open('../data/ID3/Dataset.txt').readlines()
    lines_set = lines_set[15:22]
    dataSet = [];
    for line in lines_set:
        data = line.strip().split();
        dataSet.append(data);
    return dataSet

3、計算給定數據的熵函數

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  
    labelCounts = {}  #類別字典(類別的名稱為鍵,該類別的個數為值)
    for featVec in dataSet:
        currentLabel = featVec[-1]  
        if currentLabel not in labelCounts.keys():  #還沒添加到字典里的類型
            labelCounts[currentLabel] = 0;
        labelCounts[currentLabel] += 1;
    shannonEnt = 0.0  
    for key in labelCounts:  #求出每種類型的熵
        prob = float(labelCounts[key])/numEntries  #每種類型個數占所有的比值
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt;  #返回熵

4、划分數據集,按照給定的特征划分數據集

#按照給定的特征划分數據集
def splitDataSet(dataSet, axis, value):
    retDataSet = []  
    for featVec in dataSet:  #按dataSet矩陣中的第axis列的值等於value的分數據集
        if featVec[axis] == value:      #值等於value的,每一行為新的列表(去除第axis個數據)
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])  
            retDataSet.append(reducedFeatVec) 
    return retDataSet  #返回分類后的新矩陣

5、選擇最好的數據集划分方式

#選擇最好的數據集划分方式
def chooseBestFeatureToSplit(dataSet):  
    numFeatures = len(dataSet[0])-1  #求屬性的個數
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1  
    for i in range(numFeatures):  #求所有屬性的信息增益
        featList = [example[i] for example in dataSet]  
        uniqueVals = set(featList)  #第i列屬性的取值(不同值)數集合
        newEntropy = 0.0  
        for value in uniqueVals:  #求第i列屬性每個不同值的熵*他們的概率
            subDataSet = splitDataSet(dataSet, i , value)  
            prob = len(subDataSet)/float(len(dataSet))  #求出該值在i列屬性中的概率
            newEntropy += prob * calcShannonEnt(subDataSet)  #求i列屬性各值對於的熵求和
        infoGain = baseEntropy - newEntropy  #求出第i列屬性的信息增益
        if(infoGain > bestInfoGain):  #保存信息增益最大的信息增益值以及所在的下表(列值i)
            bestInfoGain = infoGain  
            bestFeature = i  

    return bestFeature  

6、遞歸創建樹
6.1、找出出現次數最多的分類名稱的函數

#找出出現次數最多的分類名稱
def majorityCnt(classList):  
    classCount = {}  
    for vote in classList:  
        if vote not in classCount.keys(): classCount[vote] = 0  
        classCount[vote] += 1  
    sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0] 

6.2、用於創建樹的函數代碼

#創建樹
def createTree(dataSet, labels):  
    classList = [example[-1] for example in dataSet];    #創建需要創建樹的訓練數據的結果列表(例如最外層的列表是[N, N, Y, Y, Y, N, Y])
    if classList.count(classList[0]) == len(classList):  #如果所有的訓練數據都是屬於一個類別,則返回該類別
        return classList[0];  
    if (len(dataSet[0]) == 1):  #訓練數據只給出類別數據(沒給任何屬性值數據),返回出現次數最多的分類名稱
        return majorityCnt(classList);

    bestFeat = chooseBestFeatureToSplit(dataSet);   #選擇信息增益最大的屬性進行分(返回值是屬性類型列表的下標)
    bestFeatLabel = labels[bestFeat]  #根據下表找屬性名稱當樹的根節點
    myTree = {bestFeatLabel:{}}  #以bestFeatLabel為根節點建一個空樹
    del(labels[bestFeat])  #從屬性列表中刪掉已經被選出來當根節點的屬性
    featValues = [example[bestFeat] for example in dataSet]  #找出該屬性所有訓練數據的值(創建列表)
    uniqueVals = set(featValues)  #求出該屬性的所有值得集合(集合的元素不能重復)
    for value in uniqueVals:  #根據該屬性的值求樹的各個分支
        subLabels = labels[:]  
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  #根據各個分支遞歸創建樹
    return myTree  #生成的樹

7、使用決策樹對測試數據進行分類的函數

#實用決策樹進行分類
def classify(inputTree, featLabels, testVec):  
    firstStr = inputTree.keys()[0]  
    secondDict = inputTree[firstStr]  
    featIndex = featLabels.index(firstStr)  
    for key in secondDict.keys():  
        if testVec[featIndex] == key:  
            if type(secondDict[key]).__name__ == 'dict':  
                classLabel = classify(secondDict[key], featLabels, testVec)  
            else: classLabel = secondDict[key]  
    return classLabel 

8、以上提供的是各個功能封裝好的函數,下面開始調用這些函數來對測試集進行分類

myDat, labels = ID3.createTrainData()  
myTree = ID3.createTree(myDat,labels) 
print myTree
bootList = ['outlook','temperature', 'humidity', 'windy'];
testList = ID3.createTestData();
for testData in testList:
    dic = ID3.classify(myTree, bootList, testData)
    print dic

注:上面代碼中使用到了一些庫,所以在前面import以下庫

from numpy import *
from scipy import *
from math import log
import operator

(兩種import的方式,from xx import * :在該文件內使用xx里的函數就像在該文件寫的函數一樣,直接使用函數名即可;而import xx:要在該文件調用xx庫的f()函數時,要使用xx.f()。主要是因為這兩種import方式使用不同的機制。有興趣的可以另外查資料了解具體背后的機制原理)

附:開始讀取txt文件中的數據的時候,讀出來的字符串有點古怪,每個單詞的各個字母間自動添加了奇怪的字符,開始以為是讀取方式有問題,找了很久。最后發現是編碼問題,老師提供的txt文件使用的是Unicode編碼,而我的編輯器里設置的是UTF-8編碼(把老師提供的txt文件(Unicode編碼)拷貝到編輯器中能正常打開(不會亂碼)所以開始沒注意到這個問題)希望以后要注意文件的編碼問題。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM