對梯度下降算法的理解和實現


對梯度下降算法的理解和實現

​ 梯度下降算法是機器學習程序中非常常見的一種參數搜索算法。其他常用的參數搜索方法還有:牛頓法、坐標上升法等。

以線性回歸為背景

​ 當我們給定一組數據集合 \(D=\{(\mathbf{x^{(0)}},y^{(0)}),(\mathbf{x^{(1)}},y^{(1)}),...,(\mathbf{x^{(n)}},y^{(n)})\}\) ,其中上標為樣本標記,每個 \(\mathbf{x^{(i)}}\) 為一個 \(d\) 維向量(向量默認加粗表示)。我們在有了一定數量的樣本的情況下,希望能夠從樣本數據中提取信息或者某種模式,從而實現對新的數據也能具有一定的預測作用,這就需要我們找到一個能表示這組數據集合 \(D\) 的函數表達式。這樣我們就從離散的點得到了連續的函數曲線,從而可以預測未曾見過的輸入變量。

​ 一種常見的假設是,將輸入變量和輸出變量之間的關系假設為線性關系:$$h_\theta(\mathbf x) = \theta_0 + \theta_1x_1 + ... + \theta_kx_d= \mathbf{\theta x^T}$$ 。其中 \(h\) 為 hypothesis,是我們假設的能夠表示數據集合 \(D\) 的假設函數。而我們同時假設,存在一個 true function \(f\) ,使得樣本集合 \(D\) 中的樣本都是由該函數,加上一定的噪聲產生的(因為我們無法考慮到與響應變量 \(y\) 相關的所有的情況,也無法搜集到所有的數據)。並且很常見的,我們假設噪聲服從正態分布。機器學習的任務就是從假設函數空間 \(H=\{h_1,h_2,...,h_k\}\) 中找到一個對 true function 最好的近似。

​ 這個尋找 \(h\) 的過程,就是機器去學習的過程。

從直觀角度出發,我們可以設定這樣一個目標函數:希望有一條直線能夠距離每個樣本點的距離都十分的近,整體來看,希望距離所有樣本的距離和最近,這樣的一條直線是最有可能接近 true function 的。形式化表達,可以表示為:

也可稱之為損失函數。當該函數取得最小值時,說明當前的這個 \(h_\theta\) 是在當前數據集下對 true function 的最好的一個估計。

​ 所以問題轉化為一個最優化問題,在損失函數最小的情況下的 \(\mathbf{\theta}\) 是要求解的。這里我先舉一個直觀的例子。假設樣本集合 \(D = {(1,2),(2,2)}\)\(h_\theta(\mathbf x) = \theta_0 + \theta_1x\) ;在此數據集下求解參數 \(\mathbf{\theta}\) ,將數據帶入損失函數:

​ 以上討論,是希望能直觀化的闡明一點,當帶入所有的數據樣本后,損失函數變成了一個只和參數 \(\theta\) 相關的函數。而對於這個二元函數求極值,我相信大家都不會陌生。可以令各個變量的偏導同時為 0 來求解可能的極值點,然后在這些極值點中,尋找最小的點,即為損失函數最小值點,此時對應的 \(\theta_0,\theta_1\) 就是我們要求解的參數。帶回假設函數后,該函數就是我們對 true function 的一個近似。而預測過程就很簡單了,只要將新的輸入變量 \(x\) 帶入 \(h_\theta\) 即可得到響應變量 \(y\)

梯度下降算法

​ 闡述完背景,接下來討論梯度下降算法。你應該注意到了,之前對參數 \(\theta_0,\theta_1\) 的求解,我們是通過手動計算的方式計算出來的。事實上這個過程應當由計算機來完成,這很自然。那么一種顯而易見的方法是對手算過程用計算機模擬,即求同時使得損失函數各個偏導都為 0 的點,然后去確定所有的極小值、極大值點,然后得到最小值點。但是呢,這個計算過程,對於人去手算可能並不困難,但是對於計算機求解卻並不容易,因為這涉及到公式的推導,有時情況會很復雜,並且當損失函數形式變得復雜時,這也是不現實的。所以,一種非常簡單而又直覺化的方法被提出——梯度下降算法。

​ 梯度下降算法的直觀解釋是,在當前損失函數的某個點上,如果想要到達該函數的最低點,那么應該向下降速度最快的那個方向走一步,而這個方向,就是梯度的方向。步長采用對該方向分量的偏導值,也就是梯度的值。梯度下降算法的參數 \(\theta\) 更新公式為:

這里給出該公式的一個直觀解釋,以及它為什么可行。參考下圖:

​ 現在只考慮某個分量\(\theta_i\) 與函數 \(J\) 的關系。當初始化一個 \(\theta_i\) 為某個值,它將位於損失函數的某個點 P 上,然后在該點計算一個偏導:\(\frac{\partial J(\theta))}{\partial \theta_i}\) ,對應上圖中的深藍色箭頭,此時該偏導為負,所以按照 \(\theta\) 的更新公式:\(\theta_j = \theta_j - \frac{\partial J(\theta))}{\partial \theta_j}\) 可知 $\theta $ 將向坐標軸右方向移動,即更靠近函數的最低點。

​ 當進行了數次的迭代更新后,\(\theta\) 將不斷向損失函數的最低點靠近,而該點,正是 \(\frac{\partial J(\theta))}{\partial \theta_j} = 0\) 的點。此時 \(\theta\) 將會收斂。你會發現,這與我們手算偏導為 0 的點是相同的!而這個過程會在每個 \(\theta\) 的分量 \(\theta_j\) 上進行(相當於我們手算對所有的變元求偏導為 0)。結果如下圖:

此時便完成了對 \(\theta\) 的一個分量 \(\theta_j\) 的參數搜索。當偏導為正時,情況類似。

​ 和我們去手算損失函數的最小值不同,梯度下降算法去搜索最小點很容易陷入到局部的極小值中,最后收斂在這一點反而不能找到全局的最小值。解決這一問題的方法有很多,最常見的就是通過初始化不同的起點,以避免陷入局部極小值。另一種方法是通過合理調整學習率,通過使算法每步的步長大一些,從而跳過一些局部的“凹陷”極小值處(但是過大的學習率也會帶來問題,稍后我將展示這一點)。其實大多數,我們的目標函數都是單一凹凸性的,所以梯度下降算法一般可以工作的很好。

隨機梯度下降和批梯度下降

​ 為了得到梯度下降的具體公式,便於用計算機迭代求解,我們需要先做一些推導。我們已知損失函數:$$J(\theta) = \frac{1}{2} \sum_{i=1}^{n} (h_\theta(\mathbf x^{(i)}) - y{(i)})2$$ ,假設只有一個樣本時(對於所有樣本的情況,公式幾乎相同,只差一個求和符號),對某個 \(\theta\) 分量 \(\theta_j\) 求偏導:

所以 $\theta_j $ 的更新公式為(全部樣本集下):

​ 我們能發現,這個更新公式的形式很容易用計算機進行模擬。

​ 對於梯度下降算法的實現有很多變種,最常見的兩種策略就是隨機梯度下降批梯度下降

​ 批梯度下降的偽代碼為:

​ 隨機梯度下降的偽代碼為:

​ 其中 \(\alpha\) 為學習率,控制每次移動的步長。

​ 批梯度下降的優點是精確,損失函數的每個分量每次更新都會遍歷所有的樣本,計算偏導並進行一次更新,缺點是這樣每次計算量很大。隨機梯度下降每次使用一個樣本進行參數的更新,優點是速度快且有隨機性,缺點是每次只利用了一個樣本。

​ 對於二者之間折中的方法是隨機小批量梯度下降算法

隨機小批量梯度下降算法的實現

問題背景

​ 首先,假設問題的背景為預測橘子的售價。

​ 我們假設橘子的售價和橘子的進價、質量和新鮮程度成線性關系,並且存在一個 true function \(f\) 在根據這些 attribute 生成橘子的售價,於是假設 true function為:

\(f = 1.25 * buyinprice + 0.42 * quality + 0.33 * fresh\)

但是現實是我們無法對一個現象進行精准的建模,所以為了更好的近似現實情況,我們給 true function 添加一個噪聲項,來表示無法被模型捕獲的因素,並用這個函數來生成我們的樣本數據。所以該函數為:\(f = 1.25 * buyinprice + 0.42 * quality + 0.33 * fresh + noise\)

buyinprice = np.random.uniform(2,9,100)
quality = np.random.normal(6,1.5,100)
fresh = np.random.uniform(1,10,100)
noise = np.random.normal(0.85,0.15,100)

y = 1.25 * buyinprice + 0.42 * quality + 0.33 * fresh + noise

​ 生成數據如圖(共100組):

​ 我們可以先看一下數據 buyinprice 的分布和與 price 的直觀上的關系:

sns.regplot(x='buyinprice',y='price',data=data)

​ quality:

​ 因為 quality 對 price 的影響遠沒有 buyinprice 大,所以數據顯得比較分散。也就是 quality 與 price 的關系受到另一個維度 buyinprice 的擾動非常大。fresh 與此相似。

接下來考慮進行我們的機器學習程序的設計。

​ 假設線性回歸模型:\(h_\theta(\mathbf{x}) = \theta_1x_1 + \theta_2x_2 + \theta_3x_3\)

​ 那么現在我們要從數據集中去學習參數,從而得到我們假設的模型的表達式。

​ 現在使用隨機小批量梯度下降算法來進行參數搜索。

	theta = [0.1,0.1,0.1]                              # initialize theta
    last_theta = [-100000,-100000,-100000]
    alpha = 0.001                                      # learning rate

    while measure_close(theta,last_theta):              
        random_pick = np.random.uniform(1,100,30)       # a small batch sample
        last_theta = theta[:]                           # reserve and copy
        for j in range(3):                             # update every Θj
            theta[j] = theta[j] - alpha * par_der(random_pick,last_theta,j)

    print(theta)

​ 可以看到 theta 的搜索過程:

​ 經過數輪迭代后:

​ 最后參數收斂在 \(\theta_1 = 1.277,\theta_2 = 0.499,\theta_3 = 0.35\),而這與我們的 true function 的參數是較為接近的,可以認為隨機小批量梯度下降算法取得了效果。

​ 然后我們觀察 buyinprice 和price 的 true function 圖像與我們通過梯度下降算法擬合出的圖像:

​ 其中藍色直線為 true function ,而紅色直線為我們通過梯度下降算法擬合出的直線,可以看到二者十分的接近。

而在整個數據集上,考慮到 quality 和 fresh 因素,得到的模型對 price 的預測 predict_price 和實際的價格 price 之間關系:

​ 能看到,二者幾乎相等。所以可以認為在訓練數據集上,我們的模型表現的非常好。

超參數調整

​ 在編寫梯度下降算法進行參數搜索時,出現了一個很有意思的 bug。剛開始很多次,我的參數搜索結果都是這樣的:

\(\theta\) 變得越來越大,而且速度非常快,很快,我得到了這個結果:

​ 它的值已經超出了數據范圍。為什么會出現這個問題?我困擾了很久。直到到想起了超參數(hyper-parameters)。

​ 我這里有兩個超參數:learning rate = 0.05,measure close = 0.1。第一個控制步長,第二個控制收斂條件。measure_close 函數的代碼如下:

def measure_close(theta,last_theta):

    res = 0
    for i in range(3):
        res += abs(theta[i] - last_theta[i])

    if(res >= 0.1):                               # hyper parameters:0.1
        return True
    else: return False 

​ 我想一幅圖可以很好的說明我遇到的問題:

​ 過大的步長使得梯度下降算法跳過了最低點,並且 \(\theta\) 朝着 x 軸的兩側不斷擴張,最后趨向於無窮。

而此時,通過不斷的調節 learning rate,和 measure close 的值,我們也能搜索到不同的 \(\theta\) 結果,直到找到一個我們覺得滿意的參數為止,這就是機器學習中的超參數調整(調參)。

下圖是我將 learning rate 設置為 0.0012 時的到的參數:

合適的超參數將會得到擬合程度更好的模型。(不考慮泛化能力)

  1. 參考資料 CS229 note1
  2. markdown 在博客園始終這么丑 :<




作者:Skipper
出處: https://www.cnblogs.com/backwords/p/13701122.html
本博客中未標明轉載的文章歸作者 Skipper 和博客園共有,歡迎轉載,但未經作者同意必須保留此段聲明,且在文章頁面明顯位置給出原文連接,否則保留追究法律責任的權利。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM