目錄
3.1 GBDT的損失函數
3.1.1 梯度提升回歸樹損失函數介紹
3.1.2 梯度提升分類樹損失函數介紹
3.2 GBDT回歸算法描述
3.2.1 平方損失GBDT算法描述
3.2.2 絕對損失GBDT算法描述
3.2.3 huber損失GBDT算法描述
3.3.1 log損失GBDT的二分類算法描述
3.3.2 log損失GBDT的多分類算法描述
4.1 證明:損失函數為平方損失時,葉節點的最佳預測為葉節點殘差均值
4.2 證明:損失函數為絕對損失時,葉節點的最佳預測為葉節點殘差的中位數。
五. 參考文獻/博文
一、GBDT
在介紹AdaBoost的時候我們講到了,AdaBoost算法是模型為加法模型,損失函數為指數函數(針對分類為題),學習算法為前向分步算法時的分類問題。而GBDT算法是模型為加法模型,學習算法為前向分步算法,基函數為CART樹(樹是回歸樹),損失函數為平方損失函數的回歸問題,為指數函數的分類問題和為一般損失函數的一般決策問題。在針對基學習器的不足上,AdaBoost算法是通過提升錯分數據點的權重來定位模型的不足,而梯度提升算法是通過算梯度來定位模型的不足。
當GBDT的損失函數是平方損失(這需要回到損失函數類型,以及損失函數對應的優化問題上)時,即時,則負梯度
,而
即為我們所說的殘差,而我們的GBDT的思想就是在每次迭代中擬合殘差來學習一個弱學習器。而殘差的方向即為我們全局最優的方向。但是當損失函數不為平方損失時,我們該如何擬合弱學習器呢?大牛Friedman提出使用損失函數負梯度的方向代替殘差方向,我們稱損失函數負梯度為偽殘差。而偽殘差的方向即為我們局部最優的方向。所以在GBDT中,當損失函數不為平方損失時,用每次迭代的局部最優方向代替全局最優方向(這種方法是不是很熟悉?)。
說了這么多,現在舉個例子來看看GBDT是如何擬合殘差來學習弱學習器的。我們可以證明,當損失函數為平方損失時,葉節點中使平方損失誤差達到最小值的是葉節點中所有值的均值;而當損失函數為絕對值損失時,葉節點中使絕對損失誤差達到最小值的是葉節點中所有值的中位數。相關證明將在最后的附錄中給出。
訓練集是4個人,A,B,C,D年齡分別是14,16,24,26。樣本中有購物金額、上網時長、經常到百度知道提問等特征。提升樹的過程如下:
從上圖可以看出,第一棵樹建立的時候使用的是原始數據,而后每一棵樹建立使用的是前n-1次的殘差來擬合弱學習器。
下面,我們就來簡單的介紹一下GBDT的基本原理和算法描述。
二. GBDT回歸樹基本模版
梯度提升算法的回歸樹基本模版,如下所示:
輸入:訓練數據集,損失函數為
輸出:回歸樹
(1)初始化:(估計使損失函數極小化的常數值,它是只有一個根節點的樹(樹不都一般只有一個根節點嗎),一般平方損失函數為節點的均值,而絕對損失函數為節點樣本的中位數)
(2)對(M表示迭代次數,即生成的弱學習器個數):
(a)對樣本,計算損失函數的負梯度在當前模型的值將它作為殘差的估計,對於平方損失函數為,它就是通常所說的殘差;而對於一般損失函數,它就是殘差的近似值(偽殘差):
(b)對擬合一個回歸樹,得到第m棵樹的葉節點區域
,
(J表示每棵樹的葉節點個數)
(c)對,利用線性搜索,估計葉節點區域的值,使損失函數最小化,計算
(d)更新
(3)得到最終的回歸樹(即是每棵樹的葉節點值相加)
三. GBDT的算法描述
3.1 GBDT的損失函數
在sklearn中梯度提升回歸樹有四種可選的損失函數(注意一下是哪個參數),分別為'ls:平方損失','lad:絕對損失','huber:huber損失','quantile:分位數損失';而在sklearn中梯度提升分類樹有兩種可選的損失函數(分類對應的損失函數類別一般是指數函數),一種是‘exponential:指數損失’,一種是‘deviance:對數損失’。下面分別介紹這幾種損失函數。
3.1.1 梯度提升回歸樹損失函數介紹
(1)ls:平方損失,這是最常見的回歸損失函數了(負梯度就是殘差),如下:
(2)lad:絕對損失,這個損失函數也很常見,如下:
對應負梯度(有必要知道負梯度是什么東西了)為:
(3)huber:huber損失,它是平方損失和絕對損失的這種產物,對於遠離中心的異常點采用絕對損失,而中心附近的點采用平方損失。這個界限一般用分位數點度量。損失函數如下:
對應的負梯度為:
(4)quantile:分位數損失,它對應的是分位數回歸的損失函數,表達式如下:
其中θ為分位數,需要我們在回歸前指定。對應的負梯度為:
對於huber損失和分位數損失主要作用就是減少異常點對損失函數的影響。
3.1.2 梯度提升分類樹損失函數介紹
(1)exponential:指數損失,表達式如下:
(2)deviance:對數損失,類似於logistic回歸的損失函數,輸出的是類別的概率,表達式如下:
下面我們來分別的介紹一下,這幾種損失函數對應GBDT算法。
3.2 GBDT回歸算法描述
3.2.1 平方損失GBDT算法描述
輸入:訓練數據集,損失函數為
輸出:回歸樹
(1)初始化:(可以證明當損失函數為平方損失時,節點的平均值即為該節點中使損失函數達到最小值的最優預測值,證明在最下面的附錄給出)
(2)對:
(a)對樣本,計算偽殘差(對於平方損失來說,偽殘差就是真殘差)
,
(b)對擬合一個回歸樹,得到第m棵樹的葉節點區域
,
(c)對,利用線性搜索,估計葉節點區域的值,使損失函數最小化,計算
,K表示第m棵樹的第j個節點中的樣本數量(為什么要除以k,因為上面說了節點的平均值為該節點中最優預測值)
上式表示的取值為第m棵樹的第j個葉節點中偽殘差的平均數
(d)更新
(3)得到最終的回歸樹
3.2.2 絕對損失GBDT算法描述
輸入:訓練數據集,損失函數為
輸出:回歸樹
(1)初始化:(可以證明當損失函數為絕對損失時,節點中樣本的中位數即為該節點中使損失函數達到最小值的最優預測值,證明在最下面的附錄給出)
(2)對:
(a)對樣本,計算偽殘差(是一個sign函數)
,
(b)對擬合一個回歸樹,得到第m棵樹的葉節點區域
,
(c)對,
,計算
上式表示的取值為第m棵樹的第j個葉節點中偽殘差的中位數
(d)更新
(3)得到最終的回歸樹
3.2.3 huber損失GBDT算法描述
輸入:訓練數據集,損失函數為
輸出:回歸樹
(1)初始化:
(2)對:
(a)對樣本,計算
表示分位數;
表示將偽殘差的百分之多少設為分位數,在sklearn中是需要我們自己設置的,默認為0.9
(b)對擬合一個回歸樹,得到第m棵樹的葉節點區域
,
(c)對,
,計算
(d)更新
(3)得到最終的回歸樹
3.3 GBDT分類算法描述
GBDT分類算法思想上和GBDT的回歸算法沒有什么區別,但是由於樣本輸出不是連續值,而是離散類別,導致我們無法直接從輸出類別去擬合類別輸出誤差。為了解決這個問題,主要有兩種方法。一是用指數損失函數,此時GBDT算法退化為AdaBoost算法。另一種方法是用類似於邏輯回歸的對數似然損失函數的方法。也就是說,我們用的是類別的預測概率值和真實概率值的差來擬合損失。當損失函數為指數函數時,類似於AdaBoost算法,這里不做介紹,下面介紹損失函數為log(對數)函數時的GBDT二分類和多分類算法。
3.3.1 log損失GBDT的二分類算法描述
輸入:訓練數據集,損失函數為
,y={-1,1}
輸出:分類樹
(1)初始化:
(2)對:
(a)對樣本,計算偽殘差
(b)對概率殘差擬合一個分類樹,得到第m棵樹的葉節點區域
,
(c)對,
,計算
(d)更新
(3)得到最終的分類樹
由於我們用的是類別的預測概率值和真實概率值的差來擬合損失,所以最后還要講概率轉換為類別,如下:
最終輸出比較類別概率大小,概率大的就預測為該類別。
3.3.2 log損失GBDT的多分類算法描述
輸入:訓練數據集,損失函數為
,
={0,1}表示是否屬於第k類別,1表示是,0表示否。
,表示共有多少分類的類別。
輸出:分類樹
(1)初始化:
,
(2)對:
(a)計算樣本點俗屬於每個類別的概率:
(b)對k=1,2,...,K:
1) ,
2)對概率偽殘差擬合一個分類樹
3)
4)
(3)得到最終的分類樹
最后得到的可以被用來去得到分為第k類的相應的概率
:
由於我們用的是類別的預測概率值和真實概率值的差來擬合損失,所以最后還要將概率轉換為類別,如下:
為最終的輸出類別,
為當真實值為
時,預測為第k類時的聯合代價,即概率最大的類別即為我們所預測的類別。當K=2時,該算法等價於為二分類算法。
到這里,我們算法的描述環節已經介紹完畢。還有一個算法就是分位數回歸的算法描述沒有介紹,因為早期的論文里面並沒有介紹到該算法,所以,這里我們也不予以介紹,感興趣的小伙伴可以查閱相關資料或者直接看sklearn有關該算法的源碼。
最后,我們還有兩個證明沒有說,接下來我們證明我們在上面提到的有關損失函數為平方損失時葉節點的最佳預測為葉節點的殘差均值和損失函數為絕對損失時,葉節點的最佳預測為葉節點殘差的中位數。
四. 附錄
4.1 證明:損失函數為平方損失時,葉節點的最佳預測為葉節點殘差均值
節點R中有N個樣本點,假設s為切分點,,
分別為切分后的左節點和右節點,分別有節點個數為
。
我們的目標是找到切分點s,在,
內部使平方損失誤差達到最小值的
,如下:
和
分別對
求偏導,並令偏導等於0,得到在
,
內部使平方損失誤差達到最小值的
:
,
而和即為各自葉節點中的殘差的均值。
4.2 證明:損失函數為絕對損失時,葉節點的最佳預測為葉節點殘差的中位數。
損失函數
假設在節點中有個節點使
,則有
個節點使
,那么:
我們的目標是是損失函數最小化,所以,上式對求偏導,並令偏導等於0,得:
得:
而N為節點中樣本的總數,所以使節點的最佳預測為節點中殘差的中位數。
五. 參考文獻/博文
(2)《統計學習方法》第八章