最小二乘法是機器學習中的基礎知識點,一致對最小二乘法的理解不夠深入,今天就花點時間來深入理解和探討一下最小二乘法
最小二乘法,又稱最小平方法,基本公式通俗來講,二者先取個差值,在來個平方,最后搞一個和號上去,這就是最小二乘問題的思想,下面介紹下
最小二乘法
我們以最簡單的一元線性模型來解釋最小二乘法。什么是一元線性模型呢? 監督學習中,如果預測的變量是離散的,我們稱其為分類(如決策樹,支持向量機等),如果預測的變量是連續的,我們稱其為回歸。回歸分析中,如果只包括一個自變量和一個因變量,且二者的關系可用一條直線近似表示,這種回歸分析稱為一元線性回歸分析。如果回歸分析中包括兩個或兩個以上的自變量,且因變量和自變量之間是線性關系,則稱為多元線性回歸分析。對於二維空間線性是一條直線;對於三維空間線性是一個平面,對於多維空間線性是一個超平面...
對於一元線性回歸模型, 假設從總體中獲取了n組觀察值(X1,Y1),(X2,Y2), …,(Xn,Yn)。對於平面中的這n個點,可以使用無數條曲線來擬合。要求樣本回歸函數盡可能好地擬合這組值。綜合起來看,這條直線處於樣本數據的中心位置最合理。 選擇最佳擬合曲線的標准可以確定為:使總的擬合誤差(即總殘差)達到最小。有以下三個標准可以選擇:
(1)用“殘差和最小”確定直線位置是一個途徑。但很快發現計算“殘差和”存在相互抵消的問題。
(2)用“殘差絕對值和最小”確定直線位置也是一個途徑。但絕對值的計算比較麻煩。
(3)最小二乘法的原則是以“殘差平方和最小”確定直線位置。用最小二乘法除了計算比較方便外,得到的估計量還具有優良特性。這種方法對異常值非常敏感。
最常用的是普通最小二乘法( Ordinary Least Square,OLS):所選擇的回歸模型應該使所有觀察值的殘差平方和達到最小。(Q為殘差平方和)- 即采用平方損失函數。
樣本回歸模型:
其中ei為樣本(Xi, Yi)的誤差
平方損失函數:

則通過Q最小確定這條直線,即確定
,以
為變量,把它們看作是Q的函數,就變成了一個求極值的問題,可以通過求導數得到。求Q對兩個待估參數的偏導數:
根據數學知識我們知道,函數的極值點為偏導為0的點。
解得:

這就是最小二乘法的解法,就是求得平方損失函數的極值點。
最小二乘法分為線性和非線性兩種,線性最小二乘法很好解決,可以將公式(1)變換為矩陣方程(公式2),最后直接求解矩陣方程即可,不需要迭代,這種解被稱為“解析解”
(1)
(2)
非線性最小二乘問題則不然,它要復雜得多,沒有辦法變換為矩陣方程形式,以至於它必須將問題化簡為每一步均為可以直接求解的子問題,整個求解過程是迭代的。
線性最小二乘問題與非線性最小二乘的關系,就是非線性最小二乘問題的求解過程。
1. 對原問題中的每一個函數fi(x)在x0處進行一階泰勒展開,因為一階泰勒展開屬於線性函數(公式3),於是通過這種手段,就可以將非線性最小二乘問題簡化為線性最小二乘問題來求解。
(3)
2. 對得到的線性最小二乘問題,進行直接求解。這里面涉及到兩個矩陣,一個是雅克比矩陣(公式4),一個是赫森矩陣(公式5)。
(4)
(5)
3. 得到子問題的解析解xk+1之后(公式2),xk+1與xk之間便自然地建立了等式關系(公式6)。
(6)
4. 更新參數xk(k=k+1, k=1....n),回到步驟1,直到滿足收斂條件,得到最優解x*
沒錯,就是講非線性轉化為線性問題去解決,下面說名幾個注意點:
第一:步驟1中,一定要一階泰勒展開,不能采用二階以上,因為只有一階泰勒展開才是線性函數,才能轉換為線性最小二乘問題來直接求解。
第二:步驟2中,雅克比矩陣和赫森矩陣都是屬於子問題的,不是原問題的。
第三:步驟3中,是為了得到新求解的參數xk+1與之前參數xk之間的關系,形成一種“鏈式反應”,也就是迭代了。
第四:步驟4中,收斂條件一般有1.梯度近乎為0。2.變量變化很小。3.目標函數值變化很小等。
第五:許多優化算法,都可以用於解決非線性最小二乘問題。
第六:函數fi(x)往往都是如下形式(公式7),千萬別以為fi(x)就是hi(x)
(7)
解釋完了,一團亂麻很正常,我們致力於應用,能理解更好,實在理解不了就理解應用場景,畢竟現在都是面向場景式編程。
說白了,最小二乘法可以得到平方損失函數最小的點,也就是全局最小,通俗點就是擬合度比較好,所以我們一般都是用於擬合數據建立線性模型用於預測
下面給出線性最小二乘法的Java實現:
package org.yujoo.baas.base; /** * 最小二乘法 y=ax+b * * @author yu joo * */ public class Theleastsquaremethod { private static double a; private static double b; private static int num; /** * 訓練 * * @param x * @param y */ public static void train(double x[], double y[]) { num = x.length < y.length ? x.length : y.length; calCoefficientes(x,y); } /** * a=(NΣxy-ΣxΣy)/(NΣx^2-(Σx)^2) * b=y(平均)-a*x(平均) * @param x * @param y * @return */ public static void calCoefficientes (double x[],double y[]){ double xy=0.0,xT=0.0,yT=0.0,xS=0.0; for(int i=0;i<num;i++){ xy+=x[i]*y[i]; xT+=x[i]; yT+=y[i]; xS+=Math.pow(x[i], 2.0); } a= (num*xy-xT*yT)/(num*xS-Math.pow(xT, 2.0)); b=yT/num-a*xT/num; } /** * 預測 * * @param xValue * @return */ public static double predict(double xValue) { System.out.println("a="+a); System.out.println("b="+b); return a * xValue + b; } public static void main(String args[]) { double[] x = { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 } ; double[] y = {23 , 44 , 32 , 56 , 33 , 34 , 55 , 65 , 45 , 55 } ; Theleastsquaremethod.train(x, y); System.out.println(Theleastsquaremethod.predict(10.0)); } }
當然如果你不想寫也可以使用Apache開源庫commons math,提供的功能更強大,
http://commons.apache.org/proper/commons-math/userguide/fitting.html
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.5</version>
</dependency>
private static void testLeastSquareMethodFromApache() { final WeightedObservedPoints obs = new WeightedObservedPoints(); obs.add(-3, 4); obs.add(-2, 2); obs.add(-1, 3); obs.add(0, 0); obs.add(1, -1); obs.add(2, -2); obs.add(3, -5); // Instantiate a third-degree polynomial fitter. final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(3); // Retrieve fitted parameters (coefficients of the polynomial function). final double[] coeff = fitter.fit(obs.toList()); for (double c : coeff) { System.out.println(c); } }
最小二乘法使用的前提條件是數據連續的而非離散,最常使用的場景就是回歸模型,在監督學習中,如果預測的變量是離散的,我們稱其為分類(如決策樹,支持向量機等),如果預測的變量是連續的,我們稱其為回歸。回歸分析中,如果只包括一個自變量和一個因變量,且二者的關系可用一條直線近似表示,這種回歸分析稱為一元線性回歸分析。如果回歸分析中包括兩個或兩個以上的自變量,且因變量和自變量之間是線性關系,則稱為多元線性回歸分析。對於二維空間線性是一條直線;對於三維空間線性是一個平面,對於多維空間線性是一個超平面。最小二乘法就是回歸問題解決的基本方法,同時,最小二乘法在數學上稱為曲線擬合。
參考1:最優化理論與算法
參考2:利用Levenberg_Marquardt算法求解無約束的非線性最小二乘問題~
參考4:http://blog.csdn.NET/wsj998689aa/article/details/41558945
