In the previous post we studied and implemented CART (Classification and Regression Trees), but that implementation mainly handled classification on discrete-valued features. Let's now look at how to implement a CART classification tree for continuous values.
Consider how continuous values differ from discrete values:
Binary splitting differs: for a discrete feature we have to search for the best two-way grouping of its values, whereas for a continuous feature we have to find a suitable split point that divides the feature into a lower and an upper part.
Here we do not consider removing a feature from the candidate set once it has been used for a split.
Splitting the data differs: the dataset is divided according to whether the feature value is greater than, or less than or equal to, the split point:
def splitDataSet(dataSet, axis, value, threshold):
    retDataSet = []
    if threshold == 'lt':
        for featVec in dataSet:
            if featVec[axis] <= value:
                retDataSet.append(featVec)
    else:
        for featVec in dataSet:
            if featVec[axis] > value:
                retDataSet.append(featVec)
    return retDataSet
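As a quick sanity check, splitDataSet can be tried on a tiny hand-made dataset (the rows and the split value below are made up purely for illustration and are not part of the 100-point dataset used later):

# Made-up toy rows: [feature0, feature1, class label]
toyData = [[1.2, 3.0, 'A'],
           [2.5, 1.1, 'B'],
           [0.7, 2.2, 'A'],
           [3.3, 0.5, 'B']]

print(splitDataSet(toyData, 0, 1.2, 'lt'))  # rows with feature 0 <= 1.2 -> [[1.2, 3.0, 'A'], [0.7, 2.2, 'A']]
print(splitDataSet(toyData, 0, 1.2, 'gt'))  # rows with feature 0 >  1.2 -> [[2.5, 1.1, 'B'], [3.3, 0.5, 'B']]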
Choosing the best feature and its best split value:
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    bestGiniGain = 1.0; bestFeature = -1; bsetValue = ""
    for i in range(numFeatures):                            # iterate over the features
        featList = [example[i] for example in dataSet]      # column of values for this feature
        uniqueVals = list(set(featList))                    # distinct values of this feature
        uniqueVals.sort()
        for value in uniqueVals:                            # try every distinct value as a split point
            GiniGain = 0.0
            # weighted Gini of the left branch (<= value)
            left_subDataSet = splitDataSet(dataSet, i, value, 'lt')
            left_prob = len(left_subDataSet)/float(len(dataSet))
            GiniGain += left_prob * calGini(left_subDataSet)
            # weighted Gini of the right branch (> value)
            right_subDataSet = splitDataSet(dataSet, i, value, 'gt')
            right_prob = len(right_subDataSet)/float(len(dataSet))
            GiniGain += right_prob * calGini(right_subDataSet)
            if (GiniGain < bestGiniGain):                   # keep the split with the smallest weighted Gini
                bestGiniGain = GiniGain                     # record the best result: feature and split value
                bestFeature = i
                bsetValue = value
    return bestFeature, bsetValue
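The function above relies on calGini, which was defined in the previous post and is not repeated here. So that this section can stand on its own, here is a minimal sketch of a Gini impurity helper that matches how it is called (a list of rows whose last column is the class label); the empty-subset guard is my own addition for the case where the split value equals the largest value of a feature:

def calGini(dataSet):
    # Gini impurity: 1 - sum over classes of (class frequency)^2.
    # An empty subset is treated as having zero impurity.
    numEntries = len(dataSet)
    if numEntries == 0:
        return 0.0
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    gini = 1.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        gini -= prob * prob
    return gini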
Generating the CART tree: overall this is much the same as the discrete-value version; the main difference is that each branch key must carry the '<=' or '>' relation with the split value:
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                 # all samples share one class, no further splitting needed
    if len(dataSet) == 1:                   # only one sample left, decide the class by majority vote
        return majorityCnt(classList)
    bestFeat, bsetValue = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:                      # no valid split found, fall back to majority vote
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    subLabels = labels[:]
    myTree[bestFeatLabel][bestFeatLabel+'<='+str(round(float(bsetValue), 3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue, 'lt'), subLabels)
    myTree[bestFeatLabel][bestFeatLabel+'>'+str(round(float(bsetValue), 3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue, 'gt'), subLabels)
    return myTree
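createTree also calls majorityCnt from the previous post (a plain majority vote over the class labels); a minimal sketch of it and a toy run are shown below, with the rows and labels made up purely for illustration:

import operator

def majorityCnt(classList):
    # Return the class label that occurs most often in classList (majority vote).
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Made-up toy data: two continuous features plus a class label.
demoData = [[1.0, 2.0, 'A'], [1.5, 0.5, 'A'], [3.0, 2.5, 'B'], [3.5, 1.0, 'B']]
demoLabels = ['x', 'y']
print(createTree(demoData, demoLabels))
# With the calGini sketch above this prints (key order may vary):
# {'x': {'x<=1.5': 'A', 'x>1.5': 'B'}}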
Let's take a look at what a continuous-value CART tree roughly looks like (the dataset is the 100-point dataset we used before).