坐標下降法(coordinate descent method)求解LASSO的推導


坐標下降法(coordinate descent method)求解LASSO推導

LASSO在尖點是singular的,因此傳統的梯度下降法、牛頓法等無法使用。常用的求解算法有最小角回歸法、coordinate descent method等。
由於coordinate descent method是相對較簡單的做法,放在第一個介紹。

坐標下降法思想

坐標下降法基於的思想很簡單,就是當面對最小化一個多元函數的問題時,我們每一次迭代的時候只改變一個目標變量的值。也就是固定其他變量不動,只在該變量的維度上尋找一個使函數最小的值。這種思想類似於貪心算法。

推導過程

定義Loss function為:

\[\frac{1}{N}\sum_{i=1}^{N}(y_i-x_i^T\cdot \beta) \]

其中,\(x_i\)是p·1維的向量,\(\beta\)是p·1維的向量。

Penalty為Lasso penalty:

\[\sum_{j=1}^p|\beta_j| \]

定義超參數為\(\lambda\)

目標函數為:

\[L=\frac{1}{N}\sum_{i=1}^{N}(y_i-x_i^T\cdot \beta+\lambda\sum_{j=1}^p|\beta_j|) \]

應用坐標下降法的思想,我們固定住\(x_k\ne x_j\)的變量,然后在每一輪迭代中只優化\(x_j\)

可以采用的迭代順序是從j=1依次到p進行迭代,然后再從j=1開始。

當固定住其他變量時,求object function的極小值就等價於求解一元LASSO的問題。

\[L=\frac{1}{N}\sum_{i=1}^{N}(r_i-\beta_jx_{ji})^2+\lambda \beta_j \tag{1} \]

其中,\(r_i=y_i-\sum_{k\ne j}x_{ik}\beta_k\),也就是只用其他變量擬合y的殘差。

將式1稍微化簡一下,可以得到:

\[L=\beta_j^2\frac{\sum_{i=1}^{N}x_{ji}^2}{N}-2\beta_j\frac{\sum_{i=1}^{N}r_ix_{ji}}{N}+\frac{\sum_{i=1}^{N}r_i^2}{N}+\lambda{|\beta_j|} \]

這是一個二次函數。由於涉及到絕對值,我們需要分兩個區間討論:\(\beta_j<0\)\(\beta_j>0\)

相當於我們將\(\beta_j\)的取值划成了兩個空間,分別討論極值。最后的極值是把這兩個空間的極值再取最小值。

  • 第一個區間, \(\beta_j>0\)
    可以觀察到object function是一個開口向上二次函數,全局最小點在\(\beta_j=\frac{2\frac{\sum r_ix_i}{N}-\lambda}{2\sum x_i^2}{N}\)處取得。
    但是我們這時的定義域限制在 \(\beta_j>0\),因此需要分類討論是否能取全局最小點:

\[if (2\frac{\sum r_ix_i}{N}-\lambda>0):\\ {\beta_j^{*}=\frac{2\frac{\sum r_ix_i}{N}-\lambda}{2\sum x_i^2}{N}}\\ Else:\\ {\beta_j^{*}=0} \]

  • 第二個區間,\(\beta_j<0\)
    全局最小點在\(\beta_j=\frac{2\frac{\sum r_ix_i}{N}+\lambda}{2\sum x_i^2}{N}\)處取得。

但是我們這時的定義域限制在 \(\beta_j<0\),因此需要分類討論是否能取全局最小點:

\[if (2\frac{\sum r_ix_i}{N}+\lambda<0):\\ {\beta_j^{*}=\frac{2\frac{\sum r_ix_i}{N}+\lambda}{2\sum x_i^2}{N}}\\ Else:\\ {\beta_j^{*}=0} \]

綜合上面的討論,

  • case1:\(2\frac{\sum r_ix_i}{N}<-\lambda\)
    \(\beta_j^{*}=\frac{2\frac{\sum r_ix_i}{N}+\lambda}{2\sum x_i^2}{N}\)

  • case2:\(-\lambda<2\frac{\sum r_ix_i}{N}<\lambda\)
    \(\beta_j^{*}=0\)

  • case3:\(\lambda<2\frac{\sum r_ix_i}{N}\)
    \(\beta_j^{*}=\frac{2\frac{\sum r_ix_i}{N}-\lambda}{2\sum x_i^2}{N}\)

定義一個軟閾值函數來統一三個case

\[\beta_j^{*}=\frac{\text{soft threshold}({2\frac{\sum r_ix_i}{N},\lambda)}}{2\frac{\sum x_i^2}{N}} \]

comment

對於用L2 loss function作為損失函數的回歸問題,由於object function是關於\(\beta\)的凸函數,因此我們一定可以找到一個全局最優點。迭代過程是收斂的。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM