決策樹是強大的,多功能的機器學習算法。
6.1 訓練和可視化一個決策樹
在iris數據集訓練DecisionTreeClassifier:
from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier iris = load_iris() X = iris.data[:, 2:] # petal length and width y = iris.target tree_clf = DecisionTreeClassifier(max_depth=2) tree_clf.fit(X, y)
可以將訓練好的決策樹打印出來:
6.2 預測
從根節點開始,如果滿足條件,則轉向左子樹,否則轉向又子樹。最終到達的葉子節點即為預測值。
決策樹的一個優點是幾乎不需要數據預處理,特別是不需要feature scaling或者centering。
基尼系數表示節點的純潔度。如果gini=0,說明該節點是純粹的,只包含一種類別。
第$i$個節點的Gini impurity:
$G_i = 1 - \sum_{k=1}^{n} p_{i,k}^2$
其中,$p_{i,k}$是類別$k$的樣本數在節點$i$中樣本總數所占的比例。
Scikit-Learn使用的是CART算法,產生的是二叉樹:非葉子節點只有兩個子節點。其它算法比如ID3可以生產具有更多子節點的決策樹。
模型解釋:白盒 Vs 黑盒:
決策樹的可解釋性很強,這被稱作白盒模型。相應的,隨機森林或者神經網絡是黑盒模型。
6.3 評估類別概率(Estimating Class Probabilities)
類別的預測概率,就是葉子節點中該類別所占的比例。
6.4 CART訓練算法
Scikit-Learn使用分類回歸樹(Classification And Regression Tree,CART)算法訓練決策樹。其思想很簡單:使用屬性$k$和相應的閾值$t_k$將訓練集分為兩個子集。搜索合適的$(k, t_k)$使得子集的純凈度最高。損失函數如下:
$J(k,t_k) = \frac{m_{left}}{m}G_{left} + \frac{m_{right}}{m}G_{right}$
其中,$G_{left}$、$G_{right}$是左、右子樹的純凈度,$m_{left}$、$m_{right}$是左、右子樹的樣本數。
將訓練集切分之后,會對子集繼續切分,這是一個遞歸過程。如果達到最大深度就會停止(通過max_depth超參數控制),或者已經找不到可以增大純凈度的切分(比如已經完全純凈)。還有一些控制切分停止的超參數:min_samples_split, min_samples_leaf,min_weight_fraction_leaf, and max_leaf_nodes。
這是一個貪心算法,雖不能達到最優,但可以得到一足夠優的結果。找到最優樹屬於NP完全(NP-Complete)問題,需要O(exp(m))時間,這使得即使是很小的訓練集也難以求解。
6.5 計算復雜度
決策樹預測過程,需要從根節點到達一個葉子節點,決策樹一般是近似平衡的,這一過程復雜度為$O(log_2(m))$,與樣本數無關。
訓練過程需要比較所有的特征,訓練復雜度是$(n times m log(m))$。
6.6 基尼系數還是熵(Gini Impurity or Entropy)?
熵:
$H_i = \sum_{k=1}^{n} p_{i,k}\ log(p_{i,k})$
二者差別不大,通常會得到相似的決策樹。Gini impurity計算起來更快,所有它是默認的。如果非要說它們的區別,Gini impurity傾向於將最頻繁的類別分在同一個分支,entropy傾向於生成更平衡的樹。
6.7 正則化超參數(Regularization Hyperparameters)
決策樹對訓練數據幾乎不做假設(與之相反,詳細模型明顯假設數據是線性的)。如果不進行約束,很容易造成過擬合。這種模型被稱作無參數模型(nonparametric model),這並不是真的沒有參數(通常有很多參數),而是參數個數不需要在訓練之前確定下來,這就有很高的自由度去擬合訓練數據。與之相反,比如線性模型這種參數模型,需要提前確定參數個數,所以其自由度是受限的,減少了過擬合的風險(但是增加了欠擬合的風險)。
為避免過擬合,需要在訓練時現在決策樹的自由度,這被稱作正則化。正則化超參數跟算法有關,但一般情況下至少可以限制決策樹的最大深度。在Scikit-Learn中這由max_depth超參數控制。
另外還有一些算法,不設限地訓練決策樹,訓練完成后會修剪不必要的節點。如果一個節點的子節點都是葉子節點,對該節點的拆分帶來的純凈度提升並不是統計學上有效的(statistically significant),那么其子節點就被認為是不必要的,會被刪除掉。
6.7 回歸
CART回歸損失函數:
$J(k, t_k) = \frac{m_{left}}{m} MSE_{left} + \frac{m_{right}}{m} MSE_{right}$
其中,
$MSE_{node} = \sum_{i \in node}(\hat{y}_{node} - y_{(i)})^2$
$\hat{y}_{node} = \frac{1}{m_{node}}\sum_{i \in node}y^{(i)}$
6.8 不穩定性(Instability)
決策樹雖然功能強大,但也有一些局限性。首先,決策樹的決策邊界都是正交直線(所有的切分都和某一個坐標軸垂直),這使得它們對數據集的旋轉很敏感。例如,下圖顯示了簡單的線性可分數據集,在左側,決策樹很容易將其切分。但是在右側,數據集旋轉45°,決策樹出現了不必要的繞彎彎。盡管二者都很好地擬合了訓練集,很明顯右側的模型難以很好地一般化。一個解決方案是使用PCA,它可以使訓練集旋轉到最好的方向。
此外,決策樹對訓練數據集微小的變動也會很敏感。
隨機森林通過許多決策樹的預測平均值,可以避免這一不穩定性。