之前對於樹剪枝一直感到很神奇;最近參考介紹手工寫了一下剪枝代碼,才算理解到底什么是剪枝。
首先要明白回歸樹作為預測的模式(剪枝是針對回歸樹而言),其實是葉子節點進行預測;所以在使用回歸樹進行預測的時候,本質都是在通過每層(每個層代表一個屬性)的值的大於和小於來作為分值,進行二叉樹的遍歷。最后預測值其實葉子節點中左值或者右值;注意這里的葉子結點也是一個結構體,對於非葉子節點而言,他的左右值是一棵樹,但是對於葉子結點而言,左右值則是一個單一的數值。
那么剪枝的原始就是找到葉子節點,如上圖所示的特征C和特征E,然后取左右值的均值,合並(merge)為一個節點。比如低於特征C,就是取值5.5,作為B樹的左節點,這樣特征C這個節點就被減掉了。
但是在剪枝的時候注意一定要和原始場景進行比較,未剪枝前的偏差和剪枝后偏差做一個比較,看看到底哪個更優秀;如果剪枝后MSE值反而更加大了,就不要價值了。這里偏差的計算值是sum(power(yHat- y, 2))來進行比較即可。
下面的就是剪枝的python實現:
1 # 所謂剪枝即使遍歷到葉子結點,然后看一下作為預測值的葉子結點,合並左右節點(即取左右子樹平均數)為一個點 2 # 但是需要比較一下合並之后的偏差和合並前的偏差,如果合並之后的方差變小了,則剪枝(取合並值),反之則保持原狀 3 def prune(tree, testData): 4 m, n = shape(testData) 5 # 如果測試在分類(分割)過程,某一類數據為0 6 if(m == 0): return getMean(tree) 7 # 下面一大段其實都是在做這一件事情:深入都葉子結點 8 # 1. 只要左右子樹中有一顆不是葉子結點,那么就以當前節點的spIndex以及spValue為分割(分類)點,對testData進行二元分類 9 # 獲得的是二元分類的數據集left set和right set 10 if(isTree(tree["left"]) or isTree(tree["right"])): 11 lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"]) 12 # 2. 繼續處理不是葉子結點左右子樹,對其進行遞歸prune(本質就是要深入到葉子結點為止) 13 if(isTree(tree["left"])): 14 tree["left"] = prune(tree["left"], lset) 15 if(isTree(tree["right"])): 16 tree["right"] = prune(tree["right"], rset) 17 18 # 左右子樹都是葉子節點了 19 if(not isTree(tree["left"]) and not isTree(tree["right"])): 20 # 那么就以當前葉子節點的spIndex以及spValue為分割(分類)點,對testData進行二元分類 21 lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"]) 22 # 計算測試數據集和預測值(葉子結點)之間的方差,剪枝前的偏差 23 errorNotMerge = sum(power(lset[:, -1] - tree["left"], 2)) + sum(power(rset[:, -1] - tree["right"],2)) 24 treeMean = (tree["left"] + tree["right"]) / 2.0 25 # 測試數據全集和樹均值(預測值)之間的方差,剪枝后偏差 26 errorMerge = sum(power(testData[:, -1] - treeMean, 2)) 27 # 看看誰的方差小,如果測試數據全集和樹均值的方差小,返回的是樹均值(葉子結點)的均值 28 if(errorMerge < errorNotMerge): 29 #print("errorMerge < errorNotMerge, treeMean is: ") 30 #print(treeMean) 31 return treeMean 32 # 如果葉子節點(預測值)的和真實值之間的方差比較小,則返回的樹,不需要剪枝 33 else: 34 #print("errorMerge > errorNotMerge, [tree] is: ") 35 #print(tree) 36 return tree 37 # 說明葉子結點剪枝效果不明顯,不需要剪枝 38 else: 39 return tree 40
那么再匯過來,如何構建一個回歸樹呢?
構建回歸樹有幾個條件,首先要有樣本數據,葉子節點的計算方式(regLeaf),以及計算一個數據集的偏差的公式(regErr);
1 from numpy import mean 2 3 # 數據集中y值的均值 4 def regLeaf(dataset): 5 return mean(dataset[:, -1]) 6 7 # 數據集中y值的方差和 8 def regErr(dataset): 9 return var(dataset[:, -1]) * shape(dataset)[0]
有了這三者之后,就可以進行構建樹了。構建樹的時候,首先將會選擇一個區分度最好的特征以及特征值,做樣本分割,然后基於分割后的樣本分別構建左子樹和右子樹,這是一個遞歸的過程,發生變化的樣本,以及基於變化的樣本產生的新的分割特征以及特征值,這個遞歸過程一直到樣本數據不再可分為止,此時獲得就是一個value,這個就是葉子結點的left/right值(非葉子節點left/right仍然是一棵樹)。
1 # 獲取最好的分割信息,這里包括分割的特征以及特征值,然后對數據進行分割,在以分割后數據為基礎繼續進行繼續創建樹,一直到數據無法再分割 2 # (feature)為none為止。 3 def createTree(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)): 4 feature, value = chooseBsetSplit(dataset, leafType, errorType, ops) 5 # left/right值直接就是數字(不再是樹了) 6 if(feature == None): 7 return value 8 retTree = {} 9 retTree["spIndex"] = feature 10 retTree["spValue"] = value 11 # chooseBsetSplit其實應該一並把mat0和mat1返回,這樣這里就不需要再計算了。 12 # 但是后來看了一下代碼,返現該函數里面有的返回分支里面是沒有mat0和mat1,所以這里再計算一下也是說的通的。 13 lset, rset = bindSplitDataset(dataset, feature, value) 14 retTree["left"] = createTree(lset, leafType, errorType, ops) 15 retTree["right"] = createTree(rset, leafType, errorType, ops) 16 17 return retTree
下面的代碼就是獲取最佳區分特征和特征值的實現
1 # 尋找最好的區分特征;為了能夠找到需要遍歷所有的特征,以及所有的特征值,然后以該特征值做分割,獲取兩個矩陣 2 # 計算兩個矩陣的方差,不斷選出方差小的作為bestIndex以及bestValue;最后將bestIndex對應的方差和原始矩陣 3 # 方差進行比較,如果發現最佳區分特征對應的兩分割矩陣方差明顯小,並且兩個矩陣的樣本數量都不是十分小; 4 # 則說明該特征是OK的 5 6 # 返回的feature信息可能是None,代表該節點就是葉子結點中left/right的值,該函數 7 def chooseBsetSplit(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)): 8 # 可容忍的偏差,在程序開始的時候,通過errorType來計算一下dataset的y值的方差和;然后用dataset的方差 9 # 和最好區分度的方差和做減法,如果發現差值比這個tolS要小,那么說明這次指定特征是失敗的;理想的差值是要大於tols 10 # 方差一定要比原始數據小到一定程度,這次屬性指定才有意義。 11 tolS = ops[0] 12 tolN = ops[1] # 特征划分的樣本的閾值,如果一分為二后,任何一個分類樣本數少於這個閾值,這次划分就取消 13 # 為什么==1就要退出? 14 if(len(set(dataset[:, -1].T.tolist()[0])) == 1): 15 #print("len(set(dataset[:, -1].T.tolist()[0])) == 1, return None feature") 16 return None, leafType(dataset) 17 m, n = shape(dataset) 18 # 注意這里errorType其實就是參數,這里參數就是一個函數,默認是regErr 19 S =errorType(dataset) 20 # 初始化best* 21 bestS = inf 22 bestIndex = 0 23 bestValue = 0 24 iterate_num = n-1 25 #print("iterate_num: " + str(iterate_num)) 26 # 遍歷所有的特征(最后一列是結果,跳過) 27 for featureIndex in range(iterate_num): 28 #print("++++++++++++++++++++++ %d turns +++++++++++++++++++++++" % (featureIndex)) 29 # 遍歷該特征的所有特征值 30 for splitValue in set(dataset[:, featureIndex].A.flatten().tolist()): 31 # 在所有訓練樣本上面(dataset)對於該特征,大於該特征值,小於特征值分別做數據分割,獲得兩個矩陣 32 mat0, mat1 = bindSplitDataset(dataset, featureIndex, splitValue) 33 # 如果分割的特征矩陣任意一個的樣本數<tolN,那么將會跳過該特征的處理,經過分割一定要達到一定的樣本數才有意義 34 # 任意一個矩陣的樣本數少說明該特征的區分度不高 35 if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 36 #print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN; splitValue: %f, shape(mat0)[0]: %d, (shape(mat1)[0]: %d, tolN: %d" % (splitValue, shape(mat0)[0], shape(mat1)[0], tolN)) 37 continue 38 #print("*************** one ok **********************") 39 # 和leafType一樣,都是參數類型為函數,計算方差和 40 newS = errorType(mat0) + errorType(mat1) 41 # 如果方差小於bestS,則用當前的方差以及特征信息做替換;到此可以看到目標就是找到區分度高並且方差小的特征,作為最好 42 # 區分特征 43 if(newS < bestS): 44 bestIndex = featureIndex 45 bestS = newS 46 bestValue = splitValue 47 # 如果S值和bestS值之差小於tolS;參見tolS的注釋。 48 if(S -bestS) < tolS: 49 #print("(S -bestS) < tolS, return feature NULL, S: %s, bestS: %s, tolS: %s" % ( str(S), str(bestS), str(tolS))) 50 return None, leafType(dataset) 51 mat0, mat1 = bindSplitDataset(dataset, bestIndex, bestValue) 52 # 這里的判斷有意義嗎?在循環體中其實已經做了這個判斷,如果不滿足也不會成為bestIndex,bestvalue; 53 if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 54 print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN") 55 return None, leafType(dataset) 56 57 return bestIndex, bestValue