1.引言
學過數據結構的同學對二叉樹應該不陌生:二叉樹是一個連通的無環圖,每個節點最多有兩個子樹的樹結構。如下圖(一)就是一個深度k=3的二叉樹。
(圖一) (圖二)
二元決策樹與此類似。不過二元決策樹是基於屬性做一系列二元(是/否)決策。每次決策從下面的兩種決策中選擇一種,然后又會引出另外兩種決策,依次類推直到葉子節點:即最終的結果。也可以理解為是對二叉樹的遍歷,或者很多層的if-else嵌套。
這里需要特別說明的是:二元決策樹中的深度算法與二叉樹中的深度算法是不一樣的。二叉樹的深度是指有多少層,而二元決策樹的深度是指經過多少層計算。以上圖(一)為例,二叉樹的深度k=3,而在二元決策樹中深度k=2。
圖二就是一個二元決策樹的例子,其中最關鍵的是如何選擇切割點:即X[0]<=-0.075中的-0.0751是如何選擇出來的?
2.二元決策樹切割點的選擇
切割點的選擇是二元決策樹最核心的部分,其基本思路是:遍歷所有數據,嘗試每個數據作為分割點,並計算此時左右兩側的數據的離差平方和,並從中找到最小值,然后找到離差平方和最小時對應的數據,它就是最佳分割點。下面通過具體的代碼講解這一過程:
import numpy import matplotlib.pyplot as plot #建立一個100數據的測試集 nPoints = 100 #x的取值范圍:-0.5~+0.5的nPoints等分 xPlot = [-0.5+1/nPoints*i for i in range(nPoints + 1)] #y值:在x的取值上加一定的隨機值或者叫噪音數據 #設置隨機數算法生成數據時的開始值,保證隨機生成的數值一致 numpy.random.seed(1) ##隨機生成寬度為0.1的標准正態分布的數值 ##上面的設置是為了保證numpy.random這步生成的數據一致 y = [s + numpy.random.normal(scale=0.1) for s in xPlot] #離差平方和列表 sumSSE = [] for i in range(1, len(xPlot)): #以xPlot[i]為界,分成左側數據和右側數據 lhList = list(xPlot[0:i]) rhList = list(xPlot[i:len(xPlot)]) #計算每側的平均值 lhAvg = sum(lhList) / len(lhList) rhAvg = sum(rhList) / len(rhList) #計算每側的離差平方和 lhSse = sum([(s - lhAvg) * (s - lhAvg) for s in lhList]) rhSse = sum([(s - rhAvg) * (s - rhAvg) for s in rhList]) #統計總的離差平方和,即誤差和 sumSSE.append(lhSse + rhSse) ##找到最小的誤差和 minSse = min(sumSSE) ##產生最小誤差和時對應的數據索引 idxMin = sumSSE.index(minSse) ##打印切割點數據及切割點位置 print("切割點位置:"+str(idxMin)) ##49 print("切割點數據:"+str(xPlot[idxMin]))##-0.010000000000000009 ##繪制離差平方和隨切割點變化而變化的曲線 plot.plot(range(1, len(xPlot)), sumSSE) plot.xlabel('Split Point Index') plot.ylabel('Sum Squared Error') plot.show()
3.使用二元決策樹擬合數據
這里使用sklearn.tree.DecisionTreeRegressor函數。下面只顯示了主要代碼,數據生成部分同上:
from sklearn import tree from sklearn.tree import DecisionTreeRegressor ##使用二元決策樹擬合數據:深度為1 ##說明numpy.array(xPlot).reshape(1, -1):這是傳入參數的需要:list->narray simpleTree = DecisionTreeRegressor(max_depth=1) simpleTree.fit(numpy.array(xPlot).reshape(-1,1), numpy.array(y).reshape(-1,1)) ##讀取訓練后的預測數據 y_pred = simpleTree.predict(numpy.array(xPlot).reshape(-1,1)) ##繪圖 plot.figure() plot.plot(xPlot, y, label='True y') plot.plot(xPlot, y_pred, label='Tree Prediction ', linestyle='--') plot.legend(bbox_to_anchor=(1,0.2)) plot.axis('tight') plot.xlabel('x') plot.ylabel('y') plot.show()
結果如下圖:
(圖三)
當深度依次為2(圖四)、深度為6(圖5)時的結果:
(圖四) (圖五)
4.二元決策樹的過度擬合
二元決策樹同普通最小二乘法一樣,都存在擬合過度的情況,如圖五所示,幾乎看不到預測值的曲線,這就是擬合過度了。判斷是否擬合過度有兩種方法:
1)觀察結果圖。這個很好理解,就是直接看繪制的對比圖。
2)比較決策樹終止節點的數目與數據的規模。生產圖五的曲線的深度是6(最深會有7層),即會有26=64個節點,而數據集中一共才有100個數據,也就是說有很多節點是只包括一個數據的。
5.二元決策樹深度的選擇
一般是通過不同深度二元決策樹的交叉驗證(前面已講過原理)來確定最佳深度的,基本思路:
1)確定深度列表:
2)設置采用幾折交叉驗證
3)計算每折交叉驗證時的樣本外數據的均方誤差
4)繪圖,觀察結果
下面就通過深度分別為1~7的10折交叉驗證來檢驗下最佳深度。
import numpy import matplotlib.pyplot as plot from sklearn import tree from sklearn.tree import DecisionTreeRegressor #建立一個100數據的測試集 nPoints = 100 #x的取值范圍:-0.5~+0.5的nPoints等分 xPlot = [-0.5+1/nPoints*i for i in range(nPoints + 1)] #y值:在x的取值上加一定的隨機值或者叫噪音數據 #設置隨機數算法生成數據時的開始值,保證隨機生成的數值一致 numpy.random.seed(1) ##隨機生成寬度為0.1的標准正態分布的數值 ##上面的設置是為了保證numpy.random這步生成的數據一致 y = [s + numpy.random.normal(scale=0.1) for s in xPlot] ##測試數據的長度 nrow = len(xPlot) ##設置二元決策樹的深度列表 depthList = [1, 2, 3, 4, 5, 6, 7] ##每個深度下的離差平方和 xvalMSE = [] ##設置n折交叉驗證 nxval = 10 ##外層循環:深度循環 for iDepth in depthList: ##每個深度下的樣本外均方誤差 oosErrors =0 ##內層循環:交叉驗證循環 for ixval in range(nxval+1): #定義訓練集和測試集標簽 xTrain = [] #訓練集 xTest = [] #測試集 yTrain = [] #訓練集標簽 yTest = [] #測試集標簽 for a in range(nrow): ##如果采用a%ixval==0的方式寫,會有除數為0的錯誤 if a%nxval != ixval%nxval: xTrain.append(xPlot[a]) yTrain.append(y[a]) else : xTest.append(xPlot[a]) yTest.append(y[a]) ##深度為max_depth=iDepth的訓練 treeModel = DecisionTreeRegressor(max_depth=iDepth) ##轉換參數類型 treeModel.fit(numpy.array(xTrain).reshape(-1, 1), numpy.array(yTrain).reshape(-1, 1)) ##讀取預測值:使用測試集獲取樣本外誤差 treePrediction = treeModel.predict(numpy.array(xTest).reshape(-1, 1)) ##離差列表:使用測試標簽獲取樣本外誤差 error = [yTest[r] - treePrediction[r] for r in range(len(yTest))] ##每個深度下的樣本外均方誤差和 oosErrors += sum([e * e for e in error]) #計算每個深度下的樣本外平均離差平方和 mse = oosErrors/nrow ##添加到離差平方和列表 xvalMSE.append(mse) ##繪圖---樣本外離差和的平方平均值隨深度變化的曲線 plot.plot(depthList, xvalMSE) plot.axis('tight') plot.xlabel('Tree Depth') plot.ylabel('Mean Squared Error') plot.show()
結果如圖:
(圖六)
從圖中可以看出,當深度為3時的效果最好,下面我們把深度調成3,觀察結果(為了效果調整了上面代碼的顏色值):
(圖七)