連續值的CART(分類回歸樹)原理和實現


上一篇我們學習和實現了CART(分類回歸樹),不過主要是針對離散值的分類實現,下面我們來看下連續值的cart分類樹如何實現

思考連續值和離散值的不同之處:

二分子樹的時候不同:離散值需要求出最優的兩個組合,連續值需要找到一個合適的分割點把特征切分為前后兩塊

這里不考慮特征的減少問題

切分數據的不同:根據大於和小於等於切分數據集

def splitDataSet(dataSet, axis, value,threshold):
    retDataSet = []
    if threshold == 'lt':
        for featVec in dataSet:
            if featVec[axis] <= value:
                retDataSet.append(featVec)
    else:
        for featVec in dataSet:
            if featVec[axis] > value:
                retDataSet.append(featVec)

    return retDataSet

選擇最好特征的最好特征值

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      
    bestGiniGain = 1.0; bestFeature = -1;bsetValue=""
    for i in range(numFeatures):        #遍歷特征
        featList = [example[i] for example in dataSet]#得到特征列
        uniqueVals = list(set(featList))       #從特征列獲取該特征的特征值的set集合
        uniqueVals.sort()
        for value in uniqueVals:# 遍歷所有的特征值
            GiniGain = 0.0
            # 左增益
            left_subDataSet = splitDataSet(dataSet, i, value,'lt')
            left_prob = len(left_subDataSet)/float(len(dataSet))
            GiniGain += left_prob * calGini(left_subDataSet)
            # print left_prob,calGini(left_subDataSet),
            # 右增益
            right_subDataSet = splitDataSet(dataSet, i, value,'gt')
            right_prob = len(right_subDataSet)/float(len(dataSet))
            GiniGain += right_prob * calGini(right_subDataSet)
            # print right_prob,calGini(right_subDataSet),         
            # print GiniGain
            if (GiniGain < bestGiniGain):       #比較是否是最好的結果
                bestGiniGain = GiniGain         #記錄最好的結果和最好的特征
                bestFeature = i
                bsetValue=value
    return bestFeature,bsetValue

生成cart:總體上和離散值的差不多,主要差別在於分支的值要加上大於或者小於等於號

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    # print dataSet
    if classList.count(classList[0]) == len(classList): 
        return classList[0]#所有的類別都一樣,就不用再划分了
    if len(dataSet) == 1: #如果沒有繼續可以划分的特征,就多數表決決定分支的類別
        return majorityCnt(classList)
    bestFeat,bsetValue = chooseBestFeatureToSplit(dataSet)
    # print bestFeat,bsetValue,labels
    bestFeatLabel = labels[bestFeat]
    if bestFeat==-1:
        return majorityCnt(classList)
    myTree = {bestFeatLabel:{}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = list(set(featValues))
    subLabels = labels[:]
    # print bsetValue
    myTree[bestFeatLabel][bestFeatLabel+'<='+str(round(float(bsetValue),3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue,'lt'),subLabels)
    myTree[bestFeatLabel][bestFeatLabel+'>'+str(round(float(bsetValue),3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue,'gt'),subLabels)
    return myTree  

 

我們看下連續值的cart大概是什么樣的(數據集是我們之前用的100個點的數據集)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM