從零開始寫代碼 ID3決策樹Python


視頻版地址B站:從零開始寫代碼 Python ID3決策樹算法分析與實現_嗶哩嗶哩_bilibili

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

代碼如下:

# author:會武術之白貓
# date:2021-11-6
import math

def createDataSet():
    # dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    # labels = ['no sufacing', 'flippers']
    dataSet = [
        [1,1,2,0,1,1,0,'感冒'],
        [2,0,3,2,0,2,2,'流感'],
        [3,0,0,1,1,1,1,'流感'],
        [0,0,1,1,1,0,1,'感冒'],
        [3,1,2,2,0,2,2,'流感'],
        [0,1,2,0,1,0,0,'感冒'],
        [2,0,2,2,0,2,2,'流感'],
        [0,1,3,0,0,1,1,'感冒']]
    labels = ['發冷','喉嚨痛','咳嗽','頭痛','鼻塞','疲勞','發燒']
    return dataSet, labels

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    # 為分類創建字典
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts.setdefault(currentLabel, 0)
        labelCounts[currentLabel] += 1

    # 計算香農墒
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt += prob * math.log2(1 / prob)
    return shannonEnt

# 定義按照某個特征進行划分的函數 splitDataSet
# 輸入三個變量(帶划分數據集, 特征,分類值)
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet  #返回不含划分特征的子集

#  定義按照最大信息增益划分數據的函數
def chooseBestFeatureToSplit(dataSet):
    numFeature = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInforGain = 0
    bestFeature = -1

    for i in range(numFeature):
        featList = [number[i] for number in dataSet] #得到某個特征下所有值
        uniqualVals = set(featList) #set無重復的屬性特征值
        newEntrogy = 0

        #求和
        for value in uniqualVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet)) #即p(t)
            newEntrogy += prob * calcShannonEnt(subDataSet) #對各子集求香農墒

        infoGain = baseEntropy - newEntrogy #計算信息增益
        #print(infoGain)

        # 最大信息增益
        if infoGain > bestInforGain:
            bestInforGain = infoGain
            bestFeature = i
    return bestFeature

# 投票表決代碼
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount.setdefault(vote, 0)
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=lambda i:i[1], reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # print(dataSet)
    # print(classList)
    # 類別相同,停止划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # 判斷是否遍歷完所有的特征,是,返回個數最多的類別
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    #按照信息增益最高選擇分類特征屬性
    bestFeat = chooseBestFeatureToSplit(dataSet) #分類編號
    bestFeatLabel = labels[bestFeat]  #該特征的label
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat]) #移除該label

    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  #子集合
        #構建數據的子集合,並進行遞歸
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    """
    :param inputTree: 決策樹
    :param featLabels: 屬性特征標簽
    :param testVec: 測試數據
    :return: 所屬分類
    """
    firstStr = list(inputTree.keys())[0] #樹的第一個屬性
    sendDict = inputTree[firstStr]

    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in sendDict.keys():

        if testVec[featIndex] == key:
            if type(sendDict[key]).__name__ == 'dict':
                classLabel = classify(sendDict[key], featLabels, testVec)
            else:
                classLabel = sendDict[key]
    return classLabel

if __name__ == '__main__':
    dataSet, labels = createDataSet()
    r = chooseBestFeatureToSplit(dataSet)
    #print(r)
    myTree = createTree(dataSet, labels)
    print(myTree)
    #  --> {'no sufacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    res = classify(myTree, ['發冷','喉嚨痛','咳嗽','頭痛','鼻塞','疲勞','發燒'], [1,1,2,0,1,1,0])
    print(res)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM