cart中回歸樹的原理和實現


前面說了那么多,一直圍繞着分類問題討論,下面我們開始學習回歸樹吧,

cart生成有兩個關鍵點

  • 如何評價最優二分結果
  • 什么時候停止和如何確定葉子節點的值

 cart分類樹采用gini系數來對二分結果進行評價,葉子節點的值使用多數表決,那么回歸樹呢?我們直接看之前的一個數據集(天氣與是否出去玩,是否出去玩改成出去玩的時間)

sunny    hot    high    FALSE    25
sunny    hot    high    TRUE    30
overcast    hot    high    FALSE    46
rainy    mild    high    FALSE    45
rainy    cool    normal    FALSE    52
rainy    cool    normal    TRUE    23
overcast    cool    normal    TRUE    43
sunny    mild    high    FALSE    35
sunny    cool    normal    FALSE    38
rainy    mild    normal    FALSE    46
sunny    mild    normal    TRUE    48
overcast    mild    high    TRUE    52
overcast    hot    normal    FALSE    44
rainy    mild    high    TRUE    30

如果用分類樹來做,結果就是這樣的,一個結果值一個節點

回歸樹切分數據集和分類樹是一樣的,那么我們如何評價一個數據集划分的好壞呢?分類樹是用gini系數衡量數據集的類別的混亂程度,同樣,我們也可以衡量數據集的回歸值的混亂程度,比較經典的是方差和標准差,由於我們需要得到和回歸值接近的值作為葉子節點的值,我們這里使用標准差吧

n是回歸值的個數,u是平均值,x是每個回歸值,S是標准差(standard deviation)

第二個問題:什么時候停止和如何確定葉子節點的值?

分類樹是特征用完或者類別都一樣;對於回歸問題回歸值都一樣的概率比較小,由於我們過程中不減少特征,所以最后肯定是一個樣本一個分支。

有人說當分支的S小於總體的5%,分支就可以結束,然后節點的值取平均值

我們看下這樣有效果不?左邊是沒有停止原始的回歸樹,右邊是加上結束條件的回歸樹,感覺效果還可以,這樣回歸樹就完成了

對比回歸樹和分類樹的實現,發現基本是就僅僅是一個函數的區別,到這里明白為什么叫分類回歸樹了嗎?

就是同樣的代碼,只需要改變一個函數,就可以實現分類或者回歸的功能的了。

下面附上回歸樹的完整代碼

# regression_tree.py
# coding:utf8
from itertools import *
from numpy import *
import operator,math
def calStDev(dataSet):
    classList = [float(example[-1]) for example in dataSet]
    n=len(classList)
    u=sum(classList)/n
    total=0
    for x in classList:
        total+=(x-u)*(x-u)
    S = math.sqrt(total)
    return S,u

def splitDataSet(dataSet, axis, values):
    retDataSet = []
    if len(values) < 2:
        for featVec in dataSet:
            if featVec[axis] == values[0]:#如果特征值只有一個,不抽取當選特征
                reducedFeatVec = featVec[:axis]     
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
    else:
        for featVec in dataSet:
            for value in values:
                if featVec[axis] == value:#如果特征值多於一個,選取當前特征
                    retDataSet.append(featVec)

    return retDataSet    
# 傳入的是一個特征值的列表,返回特征值二分的結果
def featuresplit(features):
    count = len(features)#特征值的個數
    if count < 2:
        # print features
        # print "please check sample's features,only one feature value"
        return ((features[0],),)
    # 由於需要返回二分結果,所以每個分支至少需要一個特征值,所以要從所有的特征組合中選取1個以上的組合
    # itertools的combinations 函數可以返回一個列表選多少個元素的組合結果,例如combinations(list,2)返回的列表元素選2個的組合
    # 我們需要選擇1-(count-1)的組合
    featureIndex = range(count)
    featureIndex.pop(0) 
    combinationsList = []    
    resList=[]
    # 遍歷所有的組合
    for i in featureIndex:
        temp_combination = list(combinations(features, len(features[0:i])))
        combinationsList.extend(temp_combination)
        combiLen = len(combinationsList)
    # 每次組合的順序都是一致的,並且也是對稱的,所以我們取首尾組合集合
    # zip函數提供了兩個列表對應位置組合的功能
    resList = zip(combinationsList[0:combiLen/2], combinationsList[combiLen-1:combiLen/2-1:-1])
    return resList
# 返回最好的特征以及二分特征值
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #
    bestStDev = inf; bestFeature = -1;bestBinarySplit=()
    for i in range(numFeatures):        #遍歷特征
        featList = [example[i] for example in dataSet]#得到特征列
        uniqueVals = list(set(featList))       #從特征列獲取該特征的特征值的set集合
        # 三個特征值的二分結果:
        # [(('young',), ('old', 'middle')), (('old',), ('young', 'middle')), (('middle',), ('young', 'old'))]
        for split in featuresplit(uniqueVals):
            StDev = 0.0
            if len(split)==1:
                continue
            (left,right)=split
            # print split,
            # 對於每一個可能的二分結果計算gini增益
            # 左增益
            left_subDataSet = splitDataSet(dataSet, i, left)
            left_prob = len(left_subDataSet)/float(len(dataSet))
            S,u = calStDev(left_subDataSet)
            StDev += left_prob * S
            # 右增益
            right_subDataSet = splitDataSet(dataSet, i, right)
            right_prob = len(right_subDataSet)/float(len(dataSet))
            S,u = calStDev(right_subDataSet)
            StDev += right_prob * S
            # print StDev
            if (StDev < bestStDev):       #比較是否是最好的結果
                bestStDev = StDev         #記錄最好的結果和最好的特征
                bestFeature = i
                bestBinarySplit=(left,right)
    return bestFeature,bestBinarySplit,bestStDev                  

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet,labels,originalS):
    classList = [example[-1] for example in dataSet]
    # print dataSet
    if classList.count(classList[0]) == len(classList): 
        return classList[0]#所有的類別都一樣,就不用再划分了
    if len(dataSet) == 1: #如果沒有繼續可以划分的特征,就多數表決決定分支的類別
        return majorityCnt(classList)
    bestFeat,bestBinarySplit,bestStDev = chooseBestFeatureToSplit(dataSet)
    if bestStDev < 0.05*originalS:
        return 1.0*sum(classList)/len(classList)
    # print bestFeat,bestBinarySplit,labels
    bestFeatLabel = labels[bestFeat]
    if bestFeat==-1:
        return majorityCnt(classList)
    myTree = {bestFeatLabel:{}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = list(set(featValues))
    for value in bestBinarySplit:
        subLabels = labels[:]       # #拷貝防止其他地方修改
        if len(value)<2:
            del(subLabels[bestFeat])
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels,originalS)
    return myTree  

filename="regression_sample"
dataSet=[];labels=[];
with open(filename) as f:
    for line in f:
        fields=line.strip("\n").split("\t")
        t=fields[0:-1]
        t.append(int(fields[-1]))
        dataSet.append(t)
labels=["outlook","temperature","humidity","windy"]
# print dataSet
originalS,u=calStDev(dataSet)
# print originalS,u
tree= createTree(dataSet,labels,originalS)
print tree    

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM