機器學習之最小二乘法


最小二乘法是機器學習中的基礎知識點,一致對最小二乘法的理解不夠深入,今天就花點時間來深入理解和探討一下最小二乘法

最小二乘法,又稱最小平方法,基本公式通俗來講,二者先取個差值,在來個平方,最后搞一個和號上去,這就是最小二乘問題的思想,下面介紹下

最小二乘法

   我們以最簡單的一元線性模型來解釋最小二乘法。什么是一元線性模型呢? 監督學習中,如果預測的變量是離散的,我們稱其為分類(如決策樹,支持向量機等),如果預測的變量是連續的,我們稱其為回歸。回歸分析中,如果只包括一個自變量和一個因變量,且二者的關系可用一條直線近似表示,這種回歸分析稱為一元線性回歸分析。如果回歸分析中包括兩個或兩個以上的自變量,且因變量和自變量之間是線性關系,則稱為多元線性回歸分析。對於二維空間線性是一條直線;對於三維空間線性是一個平面,對於多維空間線性是一個超平面...

   對於一元線性回歸模型, 假設從總體中獲取了n組觀察值(X1,Y1),(X2,Y2), …,(Xn,Yn)。對於平面中的這n個點,可以使用無數條曲線來擬合。要求樣本回歸函數盡可能好地擬合這組值。綜合起來看,這條直線處於樣本數據的中心位置最合理。 選擇最佳擬合曲線的標准可以確定為:使總的擬合誤差(即總殘差)達到最小。有以下三個標准可以選擇:

        (1)用“殘差和最小”確定直線位置是一個途徑。但很快發現計算“殘差和”存在相互抵消的問題。
        (2)用“殘差絕對值和最小”確定直線位置也是一個途徑。但絕對值的計算比較麻煩。
        (3)最小二乘法的原則是以“殘差平方和最小”確定直線位置。用最小二乘法除了計算比較方便外,得到的估計量還具有優良特性。這種方法對異常值非常敏感。

  最常用的是普通最小二乘法( Ordinary  Least Square,OLS):所選擇的回歸模型應該使所有觀察值的殘差平方和達到最小。(Q為殘差平方和)- 即采用平方損失函數。

  樣本回歸模型:

                                     其中ei為樣本(Xi, Yi)的誤差

   平方損失函數:

                      

   則通過Q最小確定這條直線,即確定,以為變量,把它們看作是Q的函數,就變成了一個求極值的問題,可以通過求導數得到。求Q對兩個待估參數的偏導數:

                       

    根據數學知識我們知道,函數的極值點為偏導為0的點。

    解得:

                   

 

這就是最小二乘法的解法,就是求得平方損失函數的極值點。

 

最小二乘法分為線性和非線性兩種,線性最小二乘法很好解決,可以將公式(1)變換為矩陣方程(公式2),最后直接求解矩陣方程即可,不需要迭代,這種解被稱為“解析解”

 

(1)

(2)

 非線性最小二乘問題則不然,它要復雜得多,沒有辦法變換為矩陣方程形式,以至於它必須將問題化簡為每一步均為可以直接求解的子問題,整個求解過程是迭代的。

線性最小二乘問題與非線性最小二乘的關系,就是非線性最小二乘問題的求解過程。

1. 對原問題中的每一個函數fi(x)在x0處進行一階泰勒展開,因為一階泰勒展開屬於線性函數(公式3),於是通過這種手段,就可以將非線性最小二乘問題簡化為線性最小二乘問題來求解。

               (3)

2. 對得到的線性最小二乘問題,進行直接求解。這里面涉及到兩個矩陣,一個是雅克比矩陣(公式4),一個是赫森矩陣(公式5)。

                        (4)

(5)

3. 得到子問題的解析解xk+1之后(公式2),xk+1與xk之間便自然地建立了等式關系(公式6)。

(6)

4. 更新參數xk(k=k+1, k=1....n),回到步驟1,直到滿足收斂條件,得到最優解x*

 

沒錯,就是講非線性轉化為線性問題去解決,下面說名幾個注意點:

第一:步驟1中,一定要一階泰勒展開,不能采用二階以上,因為只有一階泰勒展開才是線性函數,才能轉換為線性最小二乘問題來直接求解。

第二:步驟2中,雅克比矩陣和赫森矩陣都是屬於子問題的,不是原問題的。

第三:步驟3中,是為了得到新求解的參數xk+1與之前參數xk之間的關系,形成一種“鏈式反應”,也就是迭代了。

第四:步驟4中,收斂條件一般有1.梯度近乎為0。2.變量變化很小。3.目標函數值變化很小等。

第五:許多優化算法,都可以用於解決非線性最小二乘問題。

第六:函數fi(x)往往都是如下形式(公式7),千萬別以為fi(x)就是hi(x)

 

(7)

 

解釋完了,一團亂麻很正常,我們致力於應用,能理解更好,實在理解不了就理解應用場景,畢竟現在都是面向場景式編程。

說白了,最小二乘法可以得到平方損失函數最小的點,也就是全局最小,通俗點就是擬合度比較好,所以我們一般都是用於擬合數據建立線性模型用於預測

下面給出線性最小二乘法的Java實現:

package org.yujoo.baas.base;

/** 
 * 最小二乘法 y=ax+b 
 *  
 * @author yu joo
 *  
 */  
public class Theleastsquaremethod {  
  
    private static double a;  
  
    private static double b;  
  
    private static int num;  
  
    /** 
     * 訓練 
     *  
     * @param x 
     * @param y 
     */  
    public static void train(double x[], double y[]) {  
        num = x.length < y.length ? x.length : y.length;  
        calCoefficientes(x,y);  
    }  
  
    /** 
     * a=(NΣxy-ΣxΣy)/(NΣx^2-(Σx)^2) 
     * b=y(平均)-a*x(平均) 
     * @param x 
     * @param y 
     * @return 
     */  
    public static void calCoefficientes (double x[],double y[]){  
        double xy=0.0,xT=0.0,yT=0.0,xS=0.0;  
        for(int i=0;i<num;i++){  
            xy+=x[i]*y[i];  
            xT+=x[i];  
            yT+=y[i];  
            xS+=Math.pow(x[i], 2.0);  
        }  
        a= (num*xy-xT*yT)/(num*xS-Math.pow(xT, 2.0));  
        b=yT/num-a*xT/num;  
    }  
  
    /** 
     * 預測 
     *  
     * @param xValue 
     * @return 
     */  
    public static double predict(double xValue) {  
        System.out.println("a="+a);  
        System.out.println("b="+b);  
        return a * xValue + b;  
    }  
  
    public static void main(String args[]) {  
        double[] x = { 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 } ;    
        double[] y = {23 , 44 , 32 , 56 , 33 , 34 , 55 , 65 , 45 , 55 } ;    
        Theleastsquaremethod.train(x, y);  
        System.out.println(Theleastsquaremethod.predict(10.0));  
    }  
  
}  

 當然如果你不想寫也可以使用Apache開源庫commons math,提供的功能更強大,

http://commons.apache.org/proper/commons-math/userguide/fitting.html

 

<dependency>  
          <groupId>org.apache.commons</groupId>  
            <artifactId>commons-math3</artifactId>  
            <version>3.5</version>  
 </dependency>  

 

private static void testLeastSquareMethodFromApache() {  
        final WeightedObservedPoints obs = new WeightedObservedPoints();  
        obs.add(-3, 4);  
        obs.add(-2, 2);  
        obs.add(-1, 3);  
        obs.add(0, 0);  
        obs.add(1, -1);  
        obs.add(2, -2);  
        obs.add(3, -5);  
  
        // Instantiate a third-degree polynomial fitter.  
        final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(3);  
  
        // Retrieve fitted parameters (coefficients of the polynomial function).  
        final double[] coeff = fitter.fit(obs.toList());  
        for (double c : coeff) {  
            System.out.println(c);  
        }  
    }

最小二乘法使用的前提條件是數據連續的而非離散,最常使用的場景就是回歸模型,在監督學習中,如果預測的變量是離散的,我們稱其為分類(如決策樹,支持向量機等),如果預測的變量是連續的,我們稱其為回歸。回歸分析中,如果只包括一個自變量和一個因變量,且二者的關系可用一條直線近似表示,這種回歸分析稱為一元線性回歸分析。如果回歸分析中包括兩個或兩個以上的自變量,且因變量和自變量之間是線性關系,則稱為多元線性回歸分析。對於二維空間線性是一條直線;對於三維空間線性是一個平面,對於多維空間線性是一個超平面。最小二乘法就是回歸問題解決的基本方法,同時,最小二乘法在數學上稱為曲線擬合。

 

參考1:最優化理論與算法

參考2:利用Levenberg_Marquardt算法求解無約束的非線性最小二乘問題~

參考3:利用信賴域算法求解無約束的非線性最小二乘問題~

參考4:http://blog.csdn.NET/wsj998689aa/article/details/41558945


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM