「05」回歸的誘惑:一文讀懂線性回歸


前言

從這一篇文章開始,就正式進入「美團」算法工程師帶你入門機器學習系列的正文了,之前的幾篇算是導讀和預熱,想必大家看的並不過癮。從這里開始,我們將會以線性回歸為起點,貫通回歸方法在機器學習算法中所扮演的角色、具有的功能和使用的方法。

說起回歸,它是我們在高中時就接觸過的內容。具體的,回歸(Regression)是指研究一組隨機變量(Y1 ,Y2 ,…,Yi)和另一組隨機變量(X1,X2,…,Xk)之間關系的統計分析方法,又稱多重回歸分析。通常Y1,Y2,…,Yi是因變量,X1、X2,…,Xk是自變量。因變量,就是指被影響、決定的變量,本身不參與運算,而自變量則是指自身發生變化、改變並參與運算,最終影響因變量的變量。這些內容都是高中學習過的基礎,這里僅僅做個回顧,不深入復習。

現在,讓我們先拋開機器學習、算法、模型這類名詞,從最簡單的線性回歸來看看,到底什么是回歸(的誘惑)

 

 

線性回歸是什么?

我們前面提到過,回歸是計算因變量和自變量之間統計關系的一種方法。而線性回歸可以理解為學習變量之間線性關系的方法。作為一切回歸的基礎,它已經存在了時間長了,是無數教科書的主題。

雖然看起來,線性回歸與一些更現代的統計學習方法,比如支持向量機相比,有些過於簡單。但在我們后續章節介紹的方法中,線性回歸仍是一種非常有用的統計學習方法,它可以被用來簡單的測試數據和分析數據。 基於這個出發點,許多有趣的機器學習方法可以被看作是概線性回歸的擴展。在談及具體的回歸方法之前,讓我們先來看看線性回歸為什么叫回歸。

 

在19世紀的英國,有一位著名的生物學家高爾頓,在研究父母和孩子身高的遺傳關系時,發現了一個直線方程,通過這個方程,他幾乎准確地擬合了被調查父母的平均身高 x 和 子女平均身高 y 之前的關系:

這在當時可是一件不得了的事情,那這個方程是什么意思呢?它代表父母身高每增加1個單位, 其成年子女的平均身高只增加0.516個單位,反映了一種“衰退”效應(“回歸”到正常人平均身高)。雖然之后的x與 y變量之間並不總是具有“衰退”(回歸)關系,但是為了紀念高爾頓這位偉大的統計學家,“線性回歸”這一名稱就保留了下來。

 

一元線性回歸

回歸最簡單的情形是一元線性回歸,也即由彼此間存在線性關系的一個自變量和一個因變量組成,方程寫作Y=a+bX+ε(X是自變量,Y是因變量,ε是隨機誤差)我們先來看一個例子:

假設有如下數據點,橫軸代表一個產品的廣告費X,縱軸代表產品的銷售額Y。回歸可以看做是用Y=a+bX+ε這條直線去擬合這些數據點,也就是盡量使這些數據點與直線的距離之和(也叫作平方誤差和)最小。

通過最小二乘法或梯度下降(后面會講這兩個方法),我們得到如下方程:

那么,當一個新的廣告計划出現,我們通過已知的X(廣告投入)代入方程,就可以盡可能准確的算出預期的產品銷售額了~當然,反過來,使用一個預期的產品銷售額,也可以反推出我們需要投入的廣告費用。

 

多元線性回歸

當自變量大於1的時候,比如X=(x1, x2),我們稱它為多元線性回歸,寫作

其中,y(x)也就是我們說的因變量,x為自變量,但是0次項和1次項的系數(a, b)被一個向量w所代替。這里的w是一個簡單的矩陣線性乘法問題,對應了以下的向量

如果對於線性代數問題還有不理解的小伙伴,可以自行查閱《線性代數同濟版》

來看一個例子:給定⼀個有關房屋的數據集,其中每棟房屋的相關數據包括⾯積(平⽅⽶)、房齡(年)和價格(元)。假設我們想使⽤任意⼀棟房屋的⾯積(設x1)和房齡(設x2)來估算它的真實價格(設y)。那么x1 和x2 即每棟房屋的特征(feature),y 為標簽(label)或真實值(ground truth)。在線性回歸模型中,房屋估計價格(設​)的表達式為

其中w1,w2 是權重(weight),通常用向量

來表示,b 是偏差(bias),也就是前面一元回歸里我們用到的b。這⾥的權重和偏差都叫做線性回歸模型的參數(parameter)

 

 

線性回歸的假設

線性回歸作為被嚴謹證明過的數學方法,有7個必備的假設前提。理論上,必須滿足這7個嚴格的假設,我們才能確保線性回歸學習到的公式/方程是統計意義成立的。但在實際使用時,我們只需要滿足前3個最重要的假設即可(其他的一般都默認成立)。以后在學習其他算法前,我們也需要了解類似的假設。

關於線性回歸,最主要的3條假設如下

  • 隨機誤差的均值為0
  • 隨機誤差的方差為σ^2
  • σ^2與X的值無關

這里規定的σ^2並不是一個具體數值,只需要大於0即可。主要是為了說明隨機誤差的方差是存在的,方差(二階中心矩)不存在,比如無限大,則這個模型就是病態的,這里涉及到高等代數,感興趣的同學可以自行深入。

 

若進一步假定隨機誤差遵從正態分布,就叫做正態線性模型。若有k個自變量和1個因變量,則因變量的值分為兩部分:一部分由自變量影響,即表示為它的函數,函數形式已知且含有未知參數;另一部分由其他的未考慮因素和隨機性影響,即隨機誤差。

一般來說,隨機誤差在參數學習中起到的作用有限,但我們在真正使用模型時還是需要先看看數據是否滿足線性回歸的前提,否則容易對模型參數產生擬合問題。

 

線性回歸的本質

一般來說,回歸分析是通過規定因變量和自變量來確定變量之間的因果關系,建立回歸模型,並根據實測數據來求解模型的各個參數,然后評價回歸模型是否能夠很好的擬合實測數據,如果能夠很好的擬合,則可以根據自變量作進一步預測,比如我們提到的廣告費用與產品銷售額的關系。

當函數為參數未知的線性函數時,稱為線性回歸分析模型;當函數為參數未知的非線性函數時,稱為非線性回歸分析模型。當自變量個數大於1時稱為多元回歸,當因變量個數大於1時稱為多元回歸

當X和Y只有一個維度(一元回歸),且因變量和自變量的關系是線性關系,線性圖表示就是一條直線,而多維度(多元回歸)學習到的參數方程,體現到空間中就是一個超平面。

 

要注意的是,我們平時可能會把擬合與回歸弄混淆,但其實二者有本質區別。你可以把現實世界的數據看做“表象”,把你擬合出來的那個模型看做“本質”。由表象到本質的過程就是“回歸”。而擬合是一種得到函數的手段,常和數值領域的“插值”放在一起,也就是得到回歸函數的手段。

當回歸函數未知時,我們可以通過擬合這種手段算出回歸函數,求這個回歸函數的問題叫做回歸問題。一個是問題的類別,一個是解決方法的類別,回歸和擬合的差別就在這里。

數學理論的世界是精確的,譬如在廣告-銷量方程中,你代入x=0就能得到唯一的 y=7.1884,但這個y並不是我們真實觀測到的,而是估計值。現實世界中的數據就像散點圖,我們只能盡可能地在雜亂中尋找規律,很難100%的完美擬合一條直線出來。用數學的模型去擬合現實的數據,這就是統計。統計不像數學那么精確,統計的世界不是非黑即白的,它有“灰色地帶”,但是統計會將理論與實際間的差別表示出來,也就是“誤差”。

我們在前面學習到的公式,就是線性回歸作為一種學習算法的本質,即模型

通過求解參數w,我們知道了自變量和因變量之間的線性關系,即我們擬合的直線

這個直線就是我們學習到的模型,盡可能地學習到一個完美的W,這就是線性回歸的本質和作用,也是一切機器學習學習算法的本質——函數(參數)學習。

 

 

線性回歸的參數學習

現在我們來看一看如何學習到這個完美的W。線性回歸的目標可以理解為減少殘差平方和,回到總體均值。在探究線性回歸的學習方法之前,我們先定義如下表示,方便后續解釋:

 ​表示輸入變量(自變量),第一部分例子中的X。

 ​表示輸出變量(因變量),第一部分例子中的Y。

一對​表示一組訓練樣本。

m個訓練樣本​稱為訓練集。

 

回到上面一元回歸的例子來,既然是用直線擬合散點,為什么最終得到的直線是y = 0.0512x + 7.1884,而不是下圖中的橙色的y = 0.0624x + 5呢?畢竟這兩條線看起來都可以擬合這些數據。

我們很容發現,數據不是真的落在一條直線上,而是分布在直線周圍,所以我們要找到一個評判標准,用於評價哪條直線才是最“合適”的。這就是我們以后時常會見到的東西——損失函數

 

在這里,我們使用的損失函數叫做殘差,也就是真實值和預測值間的差值(也可以理解為距離),用公式表示是:

對於某個廣告投入​ ,我們有對應的實際銷售量​和預測出來的銷售量​(通過將​代入直線公式計算得到),計算 ​ 的值,再將其平方(為了消除負號),將所有的 ​相加,就能量化出擬合的直線和實際之間的誤差。

這里使用的均方誤差有非常好的幾何意義,它對應了常用的歐幾里得距離或簡稱"歐氏距離" (Euclidean distance),也就是圖里的數據點和直線之間的距離。基於均方誤差最小化來進行模型求解的方法,稱為“最小二乘法” (least square method)。在線性回歸中,最小二乘法就是試圖找到一條直線,使所有樣本到直線上的歐氏距離之和最小。

 

最小二乘法

求解方程參數,使

最小化的過程,稱為線性回歸模型的最小二乘"參數估計" (parameter estimation)。我們可將E(w,b)分別對W 和b求導,得到

我們令上面兩個式子的導數為零可得到W和b最優解的閉式(closed-form) 解(也就是可以直接通過公式代入算出來的解)

其中​,為X的均值

 

當X和Y為多元回歸時,我們也有多元情況下的最小二乘法,寫作

同樣地,另這個偏導數為0,我們可以得到

這個東西也叫作正規方程,因為它很正規 (。≖ˇェˇ≖。)。具體的推導和運算涉及到矩陣的逆/偽逆,比單變量情形要復雜,這里不深入展開,感興趣的同學可以自行翻閱《矩陣論》。這里我們只需要知道兩點:

1. 現實任務中XTX 往往不是滿秩矩陣.例如在許多任務中我們會遇到大量的變量,其數目甚至超過樣例數,導致X 的列數多於行數,XTX 顯然不滿秩。此時可得到出無限個解, 它們都能使均方誤差最小化

2. 我們的輸入數據X(自變量),可以寫作矩陣形式,矩陣的橫軸代表每個數據的維度(比如房屋的價格,位置,年齡),縱軸代表每個房屋

學過線性代數的同學應該知道,當行、列很多時, 這個矩陣的任何運算都需要很大的計算量。尤其是輸入變量的維度較大時(橫軸的n比較大),該算法的計算復雜度成指數級增加。

因此,正規方程的解法在真實場景中很少見,我們有另一種叫做梯度下降的方法,通過損失一定精度,來近似逼近這個最優解。對於梯度方法,這里只做一個簡單介紹,之后會有專門的一期文章來聊聊機器學習中的優化方法。

 

梯度下降

這里引用優化教材中的一張圖,這里我們把參數W寫作​,兩者其實是一種東西。現在我們來思考:既然代價函數是關於​的函數,有沒有辦法把求解過程加速或者拆解呢?

答案是有。

上圖中的藍色區域可以理解為誤差函數最小的點,也就是我們要找的參數值,因此,找到該點對應的​,即完成了任務。如何找到最低點位置對應的參數呢?答案是對代價函數(也就是我們的誤差)求偏導數

我們用大學學過的微積分方法做一個拆解,可以得到:

這就是關於變量的偏導數。要注意的是,這里的h其實就是我們的y

假設我們的函數只有兩個維度(二元回歸),給定

就是我們要求的參數,誤差函數對第一個元求偏導的結果:

誤差函數對第二個元求偏導的結果:

求得的結果怎么使用?我們對 ​ 求偏導數的意義是得到這一點上的切線的斜率,它將給我們一個向最小值移動的方向。因此,​減去偏導數,就等於​向最小值的方向移動了一步。這一步的大小由一個參數決定,也稱作學習率。用公式表達如下:

這就是機器學習中大名鼎鼎的的梯度下降。這個公式為什么這么寫,有什么意義,之后在優化方法的文章中會寫。對底層原理感興趣的話,大家可以以前去看看MIT的微積分公開課(可汗學院、網易都有),以及Boyd所寫的《凸優化》,到時候看博客就會非常通透。

這里舉一個我在知乎上看到的例子,非常具體的解釋了線性回歸求解的過程

  1. 初始化一個模型,例如 h = 2 + 3x,也就是說,我們的初始參數是 
  2. 給定一個樣本對,例如(2,4),代入模型中求得預測值,即 h = 2 + 3*2 = 8
  3. 代入代價函數公式中,求代價值,即 J = 1/2 * (8-4) ^ 2 = 8
  4. 代入偏導數公式中求兩個變量的偏導數,即 

假設我們的學習率是0.1,那么代入梯度下降公式得到 

我們得到了新的參數,即 

所以新的模型是:h = 1.6 + 2.2x,新的預測值是h = 1.6 + 2.2*2 = 6,再次計算代價函數的值:J = 1/2 * (6-4) ^ 2 = 2

比較新的模型得到的代價值2,比老模型得到的代價值8減少了6,代價越小說明我們的模型與訓練集匹配的越好,所以通過不斷的梯度下降,我們可以得到最適合訓練數據的模型h,也就是前面提到的那條直線方程。

 

線性回歸的局限

線性回歸簡單、直觀、迅速,但也有不少局限,這也是之后更多高級算法的出現原因,它們一定程度上解決了線性回歸無法解決的問題。線性回歸的局限可以歸納以下幾點:

  • 需要嚴格的假設。
  • 只能用於變量間存在簡單線性關系的數據。
  • 當數據量、數據維度大時,計算量會指數級增加。
  • 需處理異常值,對異常值很敏感,對輸入數據差異也很敏感。
  • 線性回歸存在共線性,自相關,異方差等問題。

 

結語

到這里,線性回歸的文章就告一段落了。在這一篇文章中,我們通過線性回歸,簡單了解了機器學習的方式、概念和方法,但是對於更加具體的定義,比如模型、損失函數和監督學習還沒講到,這將是我之后文章的主題。下一期文章,我們將基於線性回歸,來深入探討回歸的更多使用方法。

線性回歸的Python代碼和案例實戰在這一篇:「06」回歸的誘惑:一文讀懂線性回歸(Python實戰篇) ,代碼不多,建議大家可以自己敲一敲。

PS. 如果大家閱讀其中的數學部分有些吃力的話,可以到我的這篇文章中找對應的知識點復習:「04」機器學習、深度學習需要哪些數學知識?

 

 

課后習題

給定每月電話咨詢次數(X)和每月實際銷量,線性回歸是否可以把圖中的數據點分為不同的兩個部分?如果可以,應該怎么分?如果不可以,又是為什么?

參考文獻

  1. 《機器學習》周志華
  2. 《動手學深度學習》MXNet Community
  3.  An Introduction toStatistical Learning with Applications in R
  4.  知乎:機器學習之線性回歸
  5.  知乎:線性回歸詳解


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM