關於回歸樹的創建和剪枝


  之前對於樹剪枝一直感到很神奇;最近參考介紹手工寫了一下剪枝代碼,才算理解到底什么是剪枝。

  首先要明白回歸樹作為預測的模式(剪枝是針對回歸樹而言),其實是葉子節點進行預測;所以在使用回歸樹進行預測的時候,本質都是在通過每層(每個層代表一個屬性)的值的大於和小於來作為分值,進行二叉樹的遍歷。最后預測值其實葉子節點中左值或者右值;注意這里的葉子結點也是一個結構體,對於非葉子節點而言,他的左右值是一棵樹,但是對於葉子結點而言,左右值則是一個單一的數值。

  那么剪枝的原始就是找到葉子節點,如上圖所示的特征C和特征E,然后取左右值的均值,合並(merge)為一個節點。比如低於特征C,就是取值5.5,作為B樹的左節點,這樣特征C這個節點就被減掉了。

  但是在剪枝的時候注意一定要和原始場景進行比較,未剪枝前的偏差和剪枝后偏差做一個比較,看看到底哪個更優秀;如果剪枝后MSE值反而更加大了,就不要價值了。這里偏差的計算值是sum(power(yHat- y, 2))來進行比較即可。

  下面的就是剪枝的python實現:

 

 1 # 所謂剪枝即使遍歷到葉子結點,然后看一下作為預測值的葉子結點,合並左右節點(即取左右子樹平均數)為一個點
 2 # 但是需要比較一下合並之后的偏差和合並前的偏差,如果合並之后的方差變小了,則剪枝(取合並值),反之則保持原狀
 3 def prune(tree, testData):
 4     m, n = shape(testData)
 5     # 如果測試在分類(分割)過程,某一類數據為0
 6     if(m == 0): return getMean(tree)
 7     # 下面一大段其實都是在做這一件事情:深入都葉子結點
 8     # 1. 只要左右子樹中有一顆不是葉子結點,那么就以當前節點的spIndex以及spValue為分割(分類)點,對testData進行二元分類
 9     # 獲得的是二元分類的數據集left set和right set
10     if(isTree(tree["left"]) or isTree(tree["right"])):
11         lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"])
12     # 2. 繼續處理不是葉子結點左右子樹,對其進行遞歸prune(本質就是要深入到葉子結點為止)
13     if(isTree(tree["left"])): 
14         tree["left"] = prune(tree["left"], lset)
15     if(isTree(tree["right"])): 
16         tree["right"] = prune(tree["right"], rset)
17     
18     # 左右子樹都是葉子節點了
19     if(not isTree(tree["left"]) and not isTree(tree["right"])):
20         # 那么就以當前葉子節點的spIndex以及spValue為分割(分類)點,對testData進行二元分類
21         lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"])
22         # 計算測試數據集和預測值(葉子結點)之間的方差,剪枝前的偏差
23         errorNotMerge = sum(power(lset[:, -1] - tree["left"], 2)) + sum(power(rset[:, -1] - tree["right"],2))
24         treeMean = (tree["left"] + tree["right"]) / 2.0
25         # 測試數據全集和樹均值(預測值)之間的方差,剪枝后偏差
26         errorMerge = sum(power(testData[:, -1] - treeMean, 2))
27         # 看看誰的方差小,如果測試數據全集和樹均值的方差小,返回的是樹均值(葉子結點)的均值
28         if(errorMerge < errorNotMerge):
29             #print("errorMerge < errorNotMerge, treeMean is: ")
30             #print(treeMean)
31             return treeMean
32         # 如果葉子節點(預測值)的和真實值之間的方差比較小,則返回的樹,不需要剪枝
33         else:
34             #print("errorMerge > errorNotMerge, [tree] is: ")
35             #print(tree)
36             return tree
37     # 說明葉子結點剪枝效果不明顯,不需要剪枝
38     else:
39         return tree
40             

 

 

 

  那么再匯過來,如何構建一個回歸樹呢?

  構建回歸樹有幾個條件,首先要有樣本數據,葉子節點的計算方式(regLeaf),以及計算一個數據集的偏差的公式(regErr);

 

1 from numpy import mean
2 
3 # 數據集中y值的均值
4 def regLeaf(dataset):
5     return mean(dataset[:, -1])
6 
7 # 數據集中y值的方差和
8 def regErr(dataset):
9     return var(dataset[:, -1]) * shape(dataset)[0]

 

 

 

  有了這三者之后,就可以進行構建樹了。構建樹的時候,首先將會選擇一個區分度最好的特征以及特征值,做樣本分割,然后基於分割后的樣本分別構建左子樹和右子樹,這是一個遞歸的過程,發生變化的樣本,以及基於變化的樣本產生的新的分割特征以及特征值,這個遞歸過程一直到樣本數據不再可分為止,此時獲得就是一個value,這個就是葉子結點的left/right值(非葉子節點left/right仍然是一棵樹)。

 

 1 # 獲取最好的分割信息,這里包括分割的特征以及特征值,然后對數據進行分割,在以分割后數據為基礎繼續進行繼續創建樹,一直到數據無法再分割
 2 # (feature)為none為止。
 3 def createTree(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)):
 4     feature, value = chooseBsetSplit(dataset, leafType, errorType, ops)
 5     # left/right值直接就是數字(不再是樹了)
 6     if(feature == None):
 7         return value
 8     retTree = {}
 9     retTree["spIndex"] = feature
10     retTree["spValue"] = value
11     # chooseBsetSplit其實應該一並把mat0和mat1返回,這樣這里就不需要再計算了。
12     # 但是后來看了一下代碼,返現該函數里面有的返回分支里面是沒有mat0和mat1,所以這里再計算一下也是說的通的。
13     lset, rset = bindSplitDataset(dataset, feature, value)
14     retTree["left"] = createTree(lset, leafType, errorType, ops)
15     retTree["right"] = createTree(rset, leafType, errorType, ops)
16     
17     return retTree

  下面的代碼就是獲取最佳區分特征和特征值的實現

 

 1 # 尋找最好的區分特征;為了能夠找到需要遍歷所有的特征,以及所有的特征值,然后以該特征值做分割,獲取兩個矩陣
 2 # 計算兩個矩陣的方差,不斷選出方差小的作為bestIndex以及bestValue;最后將bestIndex對應的方差和原始矩陣
 3 # 方差進行比較,如果發現最佳區分特征對應的兩分割矩陣方差明顯小,並且兩個矩陣的樣本數量都不是十分小;
 4 # 則說明該特征是OK的
 5 
 6 # 返回的feature信息可能是None,代表該節點就是葉子結點中left/right的值,該函數
 7 def chooseBsetSplit(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)):
 8     # 可容忍的偏差,在程序開始的時候,通過errorType來計算一下dataset的y值的方差和;然后用dataset的方差
 9     # 和最好區分度的方差和做減法,如果發現差值比這個tolS要小,那么說明這次指定特征是失敗的;理想的差值是要大於tols
10     # 方差一定要比原始數據小到一定程度,這次屬性指定才有意義。
11     tolS = ops[0]
12     tolN = ops[1] # 特征划分的樣本的閾值,如果一分為二后,任何一個分類樣本數少於這個閾值,這次划分就取消
13     # 為什么==1就要退出?
14     if(len(set(dataset[:, -1].T.tolist()[0])) == 1):
15         #print("len(set(dataset[:, -1].T.tolist()[0])) == 1, return None feature")
16         return None, leafType(dataset)
17     m, n = shape(dataset)
18     # 注意這里errorType其實就是參數,這里參數就是一個函數,默認是regErr
19     S =errorType(dataset)
20     # 初始化best*
21     bestS = inf
22     bestIndex = 0
23     bestValue = 0
24     iterate_num = n-1
25     #print("iterate_num: " + str(iterate_num))
26     # 遍歷所有的特征(最后一列是結果,跳過)
27     for featureIndex in range(iterate_num):
28         #print("++++++++++++++++++++++ %d turns +++++++++++++++++++++++" % (featureIndex))
29         # 遍歷該特征的所有特征值
30         for splitValue in set(dataset[:, featureIndex].A.flatten().tolist()):
31             # 在所有訓練樣本上面(dataset)對於該特征,大於該特征值,小於特征值分別做數據分割,獲得兩個矩陣
32             mat0, mat1 = bindSplitDataset(dataset, featureIndex, splitValue)
33             # 如果分割的特征矩陣任意一個的樣本數<tolN,那么將會跳過該特征的處理,經過分割一定要達到一定的樣本數才有意義
34             # 任意一個矩陣的樣本數少說明該特征的區分度不高
35             if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
36                 #print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN; splitValue: %f, shape(mat0)[0]: %d, (shape(mat1)[0]: %d, tolN: %d" % (splitValue, shape(mat0)[0], shape(mat1)[0], tolN))
37                 continue
38             #print("*************** one ok **********************")
39             # 和leafType一樣,都是參數類型為函數,計算方差和
40             newS = errorType(mat0) + errorType(mat1)
41             # 如果方差小於bestS,則用當前的方差以及特征信息做替換;到此可以看到目標就是找到區分度高並且方差小的特征,作為最好
42             # 區分特征
43             if(newS < bestS):
44                 bestIndex = featureIndex
45                 bestS = newS
46                 bestValue = splitValue
47     # 如果S值和bestS值之差小於tolS;參見tolS的注釋。
48     if(S -bestS) < tolS:
49         #print("(S -bestS) < tolS, return feature NULL, S: %s, bestS: %s, tolS: %s" % ( str(S), str(bestS), str(tolS)))
50         return None, leafType(dataset)
51     mat0, mat1 = bindSplitDataset(dataset, bestIndex, bestValue)
52     # 這里的判斷有意義嗎?在循環體中其實已經做了這個判斷,如果不滿足也不會成為bestIndex,bestvalue;
53     if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
54         print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN")
55         return None, leafType(dataset)
56     
57     return bestIndex, bestValue

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM