回歸與梯度下降法及實現原理


回歸與梯度下降

回歸在數學上來說是給定一個點集,能夠用一條曲線去擬合之,如果這個曲線是一條直線,那就被稱為線性回歸,如果曲線是一條二次曲線,就被稱為二次回歸,回歸還有很多的變種,如locally weighted回歸,logistic回歸,等等,這個將在后面去講。

用一個很簡單的例子來說明回歸,這個例子來自很多的地方,也在很多的open source的軟件中看到,比如說weka。大概就是,做一個房屋價值的評估系統,一個房屋的價值來自很多地方,比如說面積、房間的數量(幾室幾廳)、地 段、朝向等等,這些影響房屋價值的變量被稱為特征(feature),feature在機器學習中是一個很重要的概念,有很多的論文專門探討這個東西。在 此處,為了簡單,假設我們的房屋就是一個變量影響的,就是房屋的面積。

假設有一個房屋銷售的數據如下:

面積(m^2)  銷售價錢(萬元)

123            250

150            320

87              160

102            220

…               …

這個表類似於帝都5環左右的房屋價錢,我們可以做出一個圖,x軸是房屋的面積。y軸是房屋的售價,如下:

 

 
 

如果來了一個新的面積,假設在銷售價錢的記錄中沒有的,我們怎么辦呢?

我們可以用一條曲線去盡量准的擬合這些數據,然后如果有新的輸入過來,我們可以在將曲線上這個點對應的值返回。如果用一條直線去擬合,可能是下面的樣子:

 

 
 

綠色的點就是我們想要預測的點。

首先給出一些概念和常用的符號,在不同的機器學習書籍中可能有一定的差別。

房屋銷售記錄表 - 訓練集(training set)或者訓練數據(training data), 是我們流程中的輸入數據,一般稱為x

房屋銷售價錢 - 輸出數據,一般稱為y

擬合的函數(或者稱為假設或者模型),一般寫做 y = h(x)

訓練數據的條目數(#training set), 一條訓練數據是由一對輸入數據和輸出數據組成的

輸入數據的維度(特征的個數,#features),n

下面是一個典型的機器學習的過程,首先給出一個輸入數據,我們的算法會通過一系列的過程得到一個估計的函數,這個函數有能力對沒有見過的新數據給出一個新的估計,也被稱為構建一個模型。就如同上面的線性回歸函數。

 

 
 

我們用X1,X2..Xn 去描述feature里面的分量,比如x1=房間的面積,x2=房間的朝向,等等,我們可以做出一個估計函數:

 

 
 

θ在這兒稱為參數,在這兒的意思是調整feature中每個分量的影響力,就是到底是房屋的面積更重要還是房屋的地段更重要。為了如果我們令X0 = 1,就可以用向量的方式來表示了:

 

 
 

我們程序也需要一個機制去評估我們θ是否比較好,所以說需要對我們做出的h函數進行評估,一般這個函數稱為損失函數(loss function)或者錯誤函數(error function),描述h函數不好的程度,在下面,我們稱這個函數為J函數

在這兒我們可以做出下面的一個錯誤函數:

 

 
 

這個錯誤估計函數是去對x(i)的估計值與真實值y(i)差的平方和作為錯誤估計函數,前面乘上的1/2是為了在求導的時候,這個系數就不見了。

如何調整θ以使得J(θ)取得最小值有很多方法,其中有最小二乘法(min square),是一種完全是數學描述的方法,在stanford機器學習開放課最后的部分會推導最小二乘法的公式的來源,這個來很多的機器學習和數學書 上都可以找到,這里就不提最小二乘法,而談談梯度下降法。

梯度下降法是按下面的流程進行的:

1)首先對θ賦值,這個值可以是隨機的,也可以讓θ是一個全零的向量。

2)改變θ的值,使得J(θ)按梯度下降的方向進行減少。

為了更清楚,給出下面的圖:

這是一個表示參數θ與誤差函數J(θ)的關系圖,紅色的部分是表示J(θ)有着比較高的取值,我們需要的是,能夠讓J(θ)的值盡量的低。也就是深藍色的部分。θ0,θ1表示θ向量的兩個維度。

 
 

在上面提到梯度下降法的第一步是給θ給一個初值,假設隨機給的初值是在圖上的十字點。

然后我們將θ按照梯度下降的方向進行調整,就會使得J(θ)往更低的方向進行變化,如圖所示,算法的結束將是在θ下降到無法繼續下降為止。

當然,可能梯度下降的最終點並非是全局最小點,可能是一個局部最小點,可能是下面的情況:

 
 

 

 
 

上面這張圖就是描述的一個局部最小點,這是我們重新選擇了一個初始點得到的,看來我們這個算法將會在很大的程度上被初始點的選擇影響而陷入局部最小點

下面我將用一個例子描述一下梯度減少的過程,對於我們的函數J(θ)求偏導J:(求導的過程如果不明白,可以溫習一下微積分)

 

 
 

下面是更新的過程,也就是θi會向着梯度最小的方向進行減少。θi表示更新之前的值,-后面的部分表示按梯度方向減少的量,α表示步長,也就是每次按照梯度減少的方向變化多少。

一個很重要的地方值得注意的是,梯度是有方向的,對於一個向量θ,每一維分量θi都可以求出一個梯度的方向,我們就可以找到一個整體的方向,在變化的時候,我們就朝着下降最多的方向進行變化就可以達到一個最小點,不管它是局部的還是全局的。

 
 

用更簡單的數學語言進行描述步驟2)是這樣的:

倒三角形表示梯度,按這種方式來表示,θi就不見了。

 
 

批量梯度下降法BGD

批量梯度下降法(Batch Gradient Descent,簡稱BGD)是梯度下降法最原始的形式,它的具體思路是在更新每一參數時都使用所有的樣本來進行更新,其數學形式如下:

  (1) 對上述的能量函數求偏導:

  (2) 由於是最小化風險函數,所以按照每個參數θθ的梯度負方向來更新每個θθ:

  具體的偽代碼形式為:

  repeat{    

      

        (for every j=0, ... , n)

  }

  從上面公式可以注意到,它得到的是一個全局最優解,但是每迭代一步,都要用到訓練集所有的數據,如果樣本數目mm很大,那么可想而知這種方法的迭代速度!所以,這就引入了另外一種方法,隨機梯度下降。

  優點:全局最優解;易於並行實現;

  缺點:當樣本數目很多時,訓練過程會很慢。

  從迭代的次數上來看,BGD迭代的次數相對較少。其迭代的收斂曲線示意圖可以表示如下:

隨機梯度下降法SGD

由於批量梯度下降法在更新每一個參數時,都需要所有的訓練樣本,所以訓練過程會隨着樣本數量的加大而變得異常的緩慢。隨機梯度下降法(Stochastic Gradient Descent,簡稱SGD)正是為了解決批量梯度下降法這一弊端而提出的。

  將上面的能量函數寫為如下形式:

  利用每個樣本的損失函數對θθ求偏導得到對應的梯度,來更新θθ:

  具體的偽代碼形式為:

  1. Randomly shuffle dataset;

  2. repeat{

    for i=1, ... , mm{

      

       (for j=0, ... , nn)

    }

  }

  隨機梯度下降是通過每個樣本來迭代更新一次,如果樣本量很大的情況(例如幾十萬),那么可能只用其中幾萬條或者幾千條的樣本,就已經將theta迭代到最優解了,對比上面的批量梯度下降,迭代一次需要用到十幾萬訓練樣本,一次迭代不可能最優,如果迭代10次的話就需要遍歷訓練樣本10次。但是,SGD伴隨的一個問題是噪音較BGD要多,使得SGD並不是每次迭代都向着整體最優化方向。

  優點:訓練速度快;

  缺點:准確度下降,並不是全局最優;不易於並行實現。

  從迭代的次數上來看,SGD迭代的次數較多,在解空間的搜索過程看起來很盲目。其迭代的收斂曲線示意圖可以表示如下:

小批量梯度下降法MBGD

有上述的兩種梯度下降法可以看出,其各自均有優缺點,那么能不能在兩種方法的性能之間取得一個折衷呢?即,算法的訓練過程比較快,而且也要保證最終參數訓練的准確率,而這正是小批量梯度下降法(Mini-batch Gradient Descent,簡稱MBGD)的初衷。

  MBGD在每次更新參數時使用b個樣本(b一般為10),其具體的偽代碼形式為:

  Say b=10, m=1000.

  Repeat{

    for i=1, 11, 21, 31, ... , 991{

    

    (for every j=0, ... , nn)

    }

  }

 

舉個例子:

假設我們已知門店銷量為

 

門店數X

實際銷量Y

1

13

2

14

3

20

4

21

5

25

6

30

我們如何預測門店數X與Y的關系式呢?假設我們設定為線性:Y=a0+a1X

 

接下來我們如何使用已知數據預測參數a0和a1呢?這里就是用了梯度下降法:


左側就是梯度下降法的核心內容,右側第一個公式為假設函數,第二個公式為損失函數。

其中 表示假設函數的系數,為學習率。

對我們之前的線性回歸問題運用梯度下降法,關鍵在於求出代價函數的導數,即:


直觀的表示,如下:

python代碼實現:

import sys  
#Training data set  
#each element in x represents (x1)  
x = [1,2,3,4,5,6]  
#y[i] is the output of y = theta0+ theta1 * x[1]  
y = [13,14,20,21,25,30]  
#設置允許誤差值  
epsilon = 1  
#學習率  
alpha = 0.01  
diff = [0,0]  
max_itor = 20  
error1 = 0  
error0 =0  
cnt = 0  
m = len(x)  
#init the parameters to zero  
theta0 = 0  
theta1 = 0  
while 1:  
    cnt=cnt+1  
    diff = [0,0]  
    for i in range(m):  
        diff[0]+=theta0+ theta1 * x[i]-y[i]  
        diff[1]+=(theta0+theta1*x[i]-y[i])*x[i]  
    theta0=theta0-alpha/m*diff[0]  
    theta1=theta1-alpha/m*diff[1]  
    error1=0  
    for i in range(m):  
        error1+=(theta0+theta1*x[i]-y[i])**2  
    if abs(error1-error0)< epsilon:  
        break  
    print('theta0 :%f,theta1 :%f,error:%f'%(theta0,theta1,error1))  
    if cnt>20:  
        print ('cnt>20')  
        break  
print('theta0 :%f,theta1 :%f,error:%f'%(theta0,theta1,error1))

結果如下:

theta0 :0.205000,theta1 :0.816667,error:1948.212261
theta0 :0.379367,theta1 :1.502297,error:1395.602361
theta0 :0.527993,theta1 :2.077838,error:1005.467313
theta0 :0.654988,theta1 :2.560886,error:730.017909
theta0 :0.763807,theta1 :2.966227,error:535.521394
theta0 :0.857351,theta1 :3.306283,error:398.166976
theta0 :0.938058,theta1 :3.591489,error:301.147437
theta0 :1.007975,theta1 :3.830615,error:232.599138
theta0 :1.068824,theta1 :4.031026,error:184.147948
theta0 :1.122050,theta1 :4.198911,error:149.882851
theta0 :1.168868,theta1 :4.339471,error:125.631467
theta0 :1.210297,theta1 :4.457074,error:108.448654
theta0 :1.247197,theta1 :4.555391,error:96.255537
theta0 :1.280286,theta1 :4.637505,error:87.584709
theta0 :1.310171,theta1 :4.706007,error:81.400378
theta0 :1.337359,theta1 :4.763073,error:76.971413
theta0 :1.362278,theta1 :4.810533,error:73.781731
theta0 :1.385286,theta1 :4.849922,error:71.467048
theta0 :1.406686,theta1 :4.882532,error:69.770228
theta0 :1.426731,theta1 :4.909448,error:68.509764
theta0 :1.445633,theta1 :4.931579,error:67.557539
cnt>20
theta0 :1.445633,theta1 :4.931579,error:67.557539
[Finished in 0.2s]

可以看到學習率在0.01時,error會正常下降。圖形如下:(第一張圖是學習率小的時候,第二張圖就是學習率較大的時候)

 



所以我們再調整一下新的學習率看看是否能看到第二張圖:

我們將學習率調整成了0.3的時候得到以下結果:

theta0 :6.150000,theta1 :24.500000,error:38386.135000  
theta0 :-15.270000,theta1 :-68.932500,error:552053.226569  
theta0 :67.840125,theta1 :285.243875,error:7950988.401277  
theta0 :-245.867981,theta1 :-1059.347887,error:114525223.507401  
theta0 :946.357695,theta1 :4043.346381,error:1649619133.261223  
theta0 :-3576.913313,theta1 :-15323.055232,error:23761091159.680252  
theta0 :13591.518674,theta1 :58177.105053,error:342254436006.869995  
theta0 :-51565.747234,theta1 :-220775.317546,error:4929828278909.234375  
theta0 :195724.210360,theta1 :837920.911885,error:71009180027939.656250  
theta0 :-742803.860227,theta1 :-3180105.158068,error:1022815271242165.875000  
theta0 :2819153.863813,theta1 :12069341.864380,error:14732617369683060.000000  
theta0 :-10699395.102930,theta1 :-45806250.675551,error:212208421856953728.000000  
theta0 :40606992.787278,theta1 :173846579.256281,error:3056647245837464576.000000  
theta0 :-154114007.118001,theta1 :-659792674.286440,error:44027905696333684736.000000  
theta0 :584902509.168162,theta1 :2504083725.690765,error:634177359734604038144.000000  
theta0 :-2219856149.407590,theta1 :-9503644836.328783,error:9134682134868024885248.000000  
theta0 :8424927779.709908,theta1 :36068788150.345154,error:131575838248146814631936.000000  
theta0 :-31974778105.915466,theta1 :-136890372077.920685,error:1895216599231190653730816.000000  
theta0 :121352546013.825867,theta1 :519534337912.329712,error:27298674329760760684609536.000000  
theta0 :-460564272592.117981,theta1 :-1971767072878.787598,error:393209736799816196514906112.000000  
theta0 :1747960435714.394287,theta1 :7483365594965.919922,error:5663787744653302294061776896.000000  
cnt>20  
theta0 :1747960435714.394287,theta1 :7483365594965.919922,error:5663787744653302294061776896.000000  
[Finished in 0.2s]

可以看到theta0和theta1都在跳躍,與預期相符。

上文使用的是批量梯度下降法,如遇到大型數據集的時候這種算法非常緩慢,因為每次迭代都需要學習全部數據集,后續推出了隨機梯度下降,其實也就是抽樣學習的概念。

參考文獻


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM