一、CART決策樹模型概述(Classification And Regression Trees)
決策樹是使用類似於一棵樹的結構來表示類的划分,樹的構建可以看成是變量(屬性)選擇的過程,內部節點表示樹選擇那幾個變量(屬性)作為划分,每棵樹的葉節點表示為一個類的標號,樹的最頂層為根節點。
決策樹是通過一系列規則對數據進行分類的過程。它提供一種在什么條件下會得到什么值的類似規則的方法。決策樹算法屬於有指導的學習,即原數據必須包含預測變量和目標變量。決策樹分為分類決策樹(目標變量為分類型數值)和回歸決策樹(目標變量為連續型變量)。分類決策樹葉節點所含樣本中,其輸出變量的眾數就是分類結果;回歸樹的葉節點所含樣本中,其輸出變量的平均值就是預測結果。
決策樹是一種倒立的樹結構,它由內部節點、葉子節點和邊組成。其中最上面的一個節點叫根節點。 構造一棵決策樹需要一個訓練集,一些例子組成,每個例子用一些屬性(或特征)和一個類別標記來描述。構造決策樹的目的是找出屬性和類別間的關系,一旦這種關系找出,就能用它來預測將來未知類別的記錄的類別。這種具有預測功能的系統叫決策樹分類器。
決策樹有非常良好的優點:
1)決策樹的夠造不需要任何領域知識,就是簡單的IF...THEN...思想 ;
2)決策樹能夠很好的處理高維數據,並且能夠篩選出重要的變量;
3)由決策樹產生的結果是易於理解和掌握的;
4)決策樹在運算過程中也是非常迅速的;
5)一般而言,決策樹還具有比較理想的預測准確率。
CART決策樹又稱分類回歸樹,當數據集的因變量為連續性數值時,該樹算法就是一個回歸樹,可以用葉節點觀察的均值作為預測值;當數據集的因變量為離散型數值時,該樹算法就是一個分類樹,可以很好的解決分類問題。但需要注意的是,該算法是一個二叉樹,即每一個非葉節點只能引伸出兩個分支,所以當某個非葉節點是多水平(2個以上)的離散變量時,該變量就有可能被多次使用。
決策樹算法中包含最核心的兩個問題,即特征選擇和剪枝:
關於特征選擇目前比較流行的方法是信息增益、增益率、基尼系數和卡方檢驗,下文就先介紹基於基尼系數的特征選擇,因為本文所描述的CART決策樹就是基於基尼系數選擇特征的;
關於剪枝問題,主要分預剪枝和后剪枝,預剪枝是在樹還沒有生長之前就限定了樹的層數、葉節點觀測數量等,而后剪枝是在樹得到充分生長后,基於損失矩陣或復雜度方法實施剪枝,下文將采用后剪枝的方法對樹進行修正。
二、決策樹的核心問題
決策樹核心問題有二:一是利用Training Data完成決策樹的生成過程;二是利用Testing Data完成對決策樹的精簡過程。即前面我們提到的,生成的推理規則往往過多,精簡是必需的。
1)決策樹的生長
決策樹生長過程的本質是對Training Data反復分組(分枝)的過程,當數據分組(分枝)不再有意義——注意,什么叫分組不再有意義——時,決策樹生成過程停止。因此,決策樹生長的核心算法是確定數據分析的標准,即分枝標准。
何為有意義呢?注意,當決策樹分枝后結果差異不再顯著下降,則繼續分組沒有意義。也就是說,我們分組的目的,是為了讓輸出變量在差異上盡量小,到達葉節點時,不同葉節點上的輸出變量為相同類別,或達到用戶指定的決策樹停止生成的標准。
這樣,分枝准則涉及到兩方面問題:1、如果從眾多輸入變量中選擇最佳分組變量;2、如果從分組變量的眾多取值中找到最佳分割點。不同的決策樹算法,如C4.5、C5.0、Chaid、Quest、Cart采用了不同策略。
2)決策樹的修剪
完整的決策樹並不是一棵分類預測新數據對象的最佳樹。其原因是完整的決策樹對Training Data描述過於“精確”。我們知道,隨着決策樹的生長,決策樹分枝時所處理的樣本數量在不斷減少,決策樹對數據總體珠代表程度在不斷下降。在對根節點進行分枝時,處理的是全部樣本,再往下分枝,則是處理的不同分組下的分組下的樣本。可見隨着決策樹的生長和樣本數量的不斷減少,越深層處的節點所體現的數據特征就越個性化,可能出現如上推理規則:“年收入大於50000元且年齡大於50歲且姓名叫張三的人購買了此產品”。這種過度學習從而精確反映Training Data特征,失去一般代表性而無法應用於新數據分類預測的現象,叫過度擬合(Overfitting)或過度學習。那我們應該怎么辦呢?修剪!
常用的修剪技術有預修剪(Pre-Pruning)和后修剪(Post-Pruning)。
Pre-Pruning可以事先指定決策樹的最大深度,或最小樣本量,以防止決策樹過度生長。前提是用戶對變量聚會有較為清晰的把握,且要反復嘗試調整,否則無法給出一個合理值。注意,決策樹生長過深無法預測新數據,生長過淺亦無法預測新數據。
Post-pruning是一個邊修剪邊檢驗的過程,即在決策樹充分生長的基礎上,設定一個允許的最大錯誤率,然后一邊修剪子樹,一邊計算輸出結果的精度或誤差。當錯誤率高於最大值后,立即停止剪枝。
基於Training Data(訓練集)的Post-Pruning(剪枝)應該使用Testing Data(測試集)。
決策樹中的C4.5、C5.0、CHAID、CART和QUEST都使用了不同 剪枝策略。
案例、使用rpart()回歸樹分析糖尿病的血液化驗指標
install.packages("rpart")
library("rpart")
install.packages("rpart.plot")
library(rpart.plot)
1、主要應用函數:
1)構建回歸樹的函數:rpart()
rpart(formula, data, weights, subset,na.action = na.rpart, method,
model = FALSE, x = FALSE, y = TRUE, parms, control, cost, ...)
主要參數說明:
fomula:回歸方程形式:例如 y~x1+x2+x3。
data:數據:包含前面方程中變量的數據框(dataframe)。
na.action:缺失數據的處理辦法:默認辦法是刪除因變量缺失的觀測而保留自變量缺失的觀測。
method:根據樹末端的數據類型選擇相應變量分割方法,本參數有四種取值:連續型“anova”;離散型“class”;計數型(泊松過程)“poisson”;生存分析型“exp”。程序會根據因變量的類型自動選擇方法,但一般情況下最好還是指明本參數,以便讓程序清楚做哪一種樹模型。
parms:用來設置三個參數:先驗概率、損失矩陣、分類純度的度量方法。
cost:損失矩陣,在剪枝的時候,葉子節點的加權誤差與父節點的誤差進行比較,考慮損失矩陣的時候,從將“減少-誤差”調整為“減少-損失”
control:控制每個節點上的最小樣本量、交叉驗證的次數、復雜性參量:即cp:complexitypamemeter,這個參數意味着對每一步拆分,模型的擬合優度必須提高的程度,等等。rpart.control對樹進行一些設置
xval是10折交叉驗證
minsplit是最小分支節點數,這里指大於等於20,那么該節點會繼續分划下去,否則停止
minbucket:葉子節點最小樣本數;maxdepth:樹的深度
cp全稱為complexity parameter,指某個點的復雜度,對每一步拆分,模型的擬合優度必須提高的程度,用來節省剪枝浪費的不必要的時間。
2)進行剪枝的函數:prune()
prune(tree, cp, ...)
主要參數說明:
tree:一個回歸樹對象,常是rpart()的結果對象。
cp:復雜性參量,指定剪枝采用的閾值。cp全稱為complexity parameter,指某個點的復雜度,對每一步拆分,模型的擬合優度必須提高的程度,用來節省剪枝浪費的不必要的時間。
二、特征選擇
CART算法的特征選擇就是基於基尼系數得以實現的,其選擇的標准就是每個子節點達到最高的純度,即落在子節點中的所有觀察都屬於同一個分類。下面簡單介紹一下有關基尼系數的計算問題:
假設數據集D中的因變量有m個水平,即數據集可以分成m類群體,則數據集D的基尼系數可以表示為:
由於CART算法是二叉樹形式,所以一個多水平(m個水平)的離散變量(自變量)可以把數據集D划分為2^m-2種可能。舉個例子也許能夠明白:如果年齡段可分為{青年,中年,老年},則其子集可以是{青年,中年,老年}、{青年,中年}、{青年,老年}、{中年,老年}、{青年}、{中年}、{老年}、{}。其中{青年,中年,老年}和空集{}為無意義的Split,所以6=2^3-2。
對於一個離散變量來說,需要計算每個分區不純度的加權和,即對於變量A來說,D的基尼系數為:
對於一個連續變量來說,需要將排序后的相鄰值的中點作為閾值(分裂點),同樣使用上面的公式計算每一個分區不純度的加權和。
根據特征選擇的標准,只有使每個變量的每種分區的基尼系數達到最小,就可以確定該變量下的閾值作為分裂變量和分裂點。如果這部分讀的不易理解的話,可參考《數據挖掘:概念與技術》一書,書中有關於計算的案例。
三、剪枝
剪枝是為了防止模型過擬合,而更加適合樣本外的預測。一般決策樹中的剪枝有兩種方式,即預剪枝和后剪枝,而后剪枝是運用最為頻繁的方法。后剪枝中又分為損失矩陣剪枝法和復雜度剪枝法,對於損失矩陣剪枝法而言,是為了給錯誤預測一個懲罰系數,使其在一定程度上減少預測錯誤的情況;對於復雜度剪枝法而言,就是把樹的復雜度看作葉節點的個數和樹的錯誤率(錯誤分類觀察數的比例)的函數。這里講解的有點抽象,下面我們通過一個簡單的例子來說明后剪枝的作用。
四、案例分享
以“知識的掌握程度”數據為例,說說決策樹是如何實現數據的分類的(數據來源
:http://archive.ics.uci.edu/ml/datasets/User+Knowledge+Modeling)。
該數據集通過5個維度來衡量知識的掌握程度,它們分別是:
STG:目標科目的學習時長程度;
SCG:對目標科目的重復學習程度;
STR:其他相關科目的學習時長程度;
LPR:其他相關科目的考試成績;
PEG:目標科目的考試成績。
知識的掌握程度用UNS表示,它有4個水平,即Very Low、Low、Middle、High。
#讀取外部文件
Train <- read.csv(file = file.choose())
Test <- read.csv(file = file.choose())
#加載CART算法所需的擴展包,並構建模型
library(rpart)
fit <- rpart(UNS ~ ., data = Train)
#查看模型輸出的規則
fit
上面的輸出規則看起來有點眼花繚亂,我們嘗試用決策樹圖來描述產生的具體規則。由於rpart包中有plot函數實現決策樹圖的繪制,但其顯得很難看,我們下面使用rpart.plot包來繪制比較好看的決策樹圖:
#加載並繪制決策樹圖
library(rpart.plot)
rpart.plot(fit, branch = 1, branch.type = 1, type = 2, extra = 102,shadow.col='gray', box.col='green',border.col='blue', split.col='red',main="CART決策樹")
上圖可一目了然的查看具體的輸出規則,如根節點有258個觀測,其中Middle有88個,當PEG>=0.68時,節點內有143個觀測,其中Middle有78個,當PEG>=0.12且PEG<0.34時,節點內有115個觀察,其中Low有81個,以此類推還可以得出其他規則。
#將模型用於預測
Pred <- predict(object = fit, newdata = Test[,-6], type = 'class')
#構建混淆矩陣
CM <- table(Test[,6], Pred)
CM
#計算模型的預測准確率
Accuracy <- sum(diag(CM))/sum(CM)
Accuracy
結果顯示,模型在測試集中的預測能力超過91%。但模型的預測准確率還有提升的可能嗎?下面我們對模型進行剪枝操作,具體分損失矩陣法剪枝和復雜度剪枝:
根據混淆矩陣的顯示結果,發現High的預測率達100%(39/39),Low的預測率達91.3%(42/46),Middle的預測率達88.2%(30/34),very_low的預測率達80.8(21/26)。如果希望提升very_low的預測准確率的話就需要將其懲罰值提高,經嘗試調整,構建如下損失矩陣:
vec = c(0,1,1,1,1,0,1,1,1,2,0,1,1,3.3,1,0)
cost = matrix(vec, nrow = 4, byrow = TRUE)
cost
fit2 = rpart(UNS ~ ., data = Train, parms = list(loss = cost))
Pred2 = predict(fit2, Test[,-6], type = 'class')
CM2 <- table(Test[,6], Pred2)
CM2
Accuracy2 <- sum(diag(CM2))/sum(CM2)
Accuracy2
准確率提升了1.4%,且在保證High、Low、Middle准確率不變的情況下,提升了very_low的准確率88.5%,原來為80.8%。
下面再采用復雜度方法進行剪枝,先來看看原模型的CP值:
printcp(fit)
復雜度剪枝法滿足的條件是,在預測誤差(xerror)盡量小的情況下(不一定是最小值,而是允許最小誤差的一個標准差(xstd)之內),選擇盡量小的cp值。這里選擇cp=0.01。
fit3 = prune(fit, cp = 0.01)
Pred3 = predict(fit3, Test[,-6], type = 'class')
CM3 <- table(Test[,6], Pred3)
CM3
Accuracy3 <- sum(diag(CM3))/sum(CM3)
Accuracy3
很顯然,模型的准確率並沒有得到提升,因為這里滿足條件的cp值為0.01,而函數rpart()默認的cp值就是0.01,故模型fit3的結果與fit一致。
確定遞歸建樹的停止條件:否則會使節點過多,導致過擬合。
1. 每個子節點只有一種類型的記錄時停止,這樣會使得節點過多,導致過擬合。
2. 可行方法:當前節點中的記錄數低於一個閾值時,那就停止分割。
過擬合原因:
(1)噪音數據,某些節點用噪音數據作為了分割標准。
(2)缺少代表性的數據,訓練數據沒有包含所有具有代表性的數據,導致某類數據無法很好匹配。
(3)還就是上面的停止條件設置不好。