spark-MLlib之線性回歸


>>提君博客原創  http://www.cnblogs.com/tijun/  <<

假定線性擬合方程: 

提君博客原創

變量 X是 i 個變量或者說屬性 

參數 ai 是模型訓練的目的就是計算出這些參數的值。 

線性回歸分析的整個過程可以簡單描述為如下三個步驟:

  1. 尋找合適的預測函數,即上文中的 h(x)h(x) ,用來預測輸入數據的判斷結果。這個過程時非常關鍵的,需要對數據有一定的了解或分析,知道或者猜測預測函數的“大概”形式,比如是線性函數還是非線性函數,若是非線性的則無法用線性回歸來得出高質量的結果。
  2. 構造一個Loss函數(損失函數),該函數表示預測的輸出(h)與訓練數據標簽之間的偏差,可以是二者之間的差(h-y)或者是其他的形式(如平方差開方)。綜合考慮所有訓練數據的“損失”,將Loss求和或者求平均,記為 J(θ)J(θ) 函數,表示所有訓練數據預測值與實際類別的偏差。
  3. 顯然, J(θ)J(θ) 函數的值越小表示預測函數越准確(即h函數越准確),所以這一步需要做的是找到 J(θ)J(θ) 函數的最小值。找函數的最小值有不同的方法,Spark中采用的是梯度下降法(stochastic gradient descent, SGD)。

線性回歸同樣可以采用正則化手段,其主要目的就是防止過擬合。

當采用L1正則化時,則變成了Lasso Regresion;當采用L2正則化時,則變成了Ridge Regression;線性回歸未采用正則化手段。通常來說,在訓練模型時是建議采用正則化手段的,特別是在訓練數據的量特別少的時候,若不采用正則化手段,過擬合現象會非常嚴重。L2正則化相比L1而言會更容易收斂(迭代次數少),但L1可以解決訓練數據量小於維度的問題(也就是n元一次方程只有不到n個表達式,這種情況下是多解或無窮解的)。

提君博客原創

在spark中分三種回歸:LinearRegression、Lasso和RidgeRegression(嶺回歸)

采用L1正則化時為Lasso回歸(元素絕對值),采用L2時為RidgeRegression回歸(元素平方),沒有正則化時就是線性回歸。

比如嶺回歸的損失函數: 

顯然,損失函數值越小說明當前這條直線擬合效果越好>>提君博客原創  http://www.cnblogs.com/tijun/  <<
通常用梯度下降法 用來最小化損失值? 

spark中線性回歸算法可使用的類包括LinearRegression、LassoWithSGD、RidgeRegressionWithSGD(SGD代表隨機梯度下降法),

這幾個類都有幾個可以用來對算法調優的參數

  • numIterations 要迭代的次數
  • stepSize 梯度下降的步長(默認1.0)
  • intercept 是否給數據加上一個干擾特征或者偏差特征(默認:false)
  • regParam Lasso和ridge的正規參數(默認1.0)

下面是實例>>提君博客原創  http://www.cnblogs.com/tijun/  <<

訓練集下載

訓練集概況

-0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
-0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541
...

 

數據格式:逗號之前為label;之后為8個特征值,以空格分隔。

代碼

package com.ltt.spark.ml.example;

import org.apache.spark.api.java.*;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.GeneralizedLinearModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LassoWithSGD;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionWithSGD;

import java.util.Arrays;

import org.apache.spark.SparkConf;
import scala.Tuple2;

/**
 * 
 * Title: LinearRegresionExample.java    
 * Description: 本地代碼執行,機器學習之線性回歸 
 * <br/>
 * @author liutiti
 * @created 2017年11月21日 下午4:03:45 
 */
@SuppressWarnings("resource")
public class LinearRegresionExample {

    /**
     * 
     * @discription 程序測試入口
     * @author liutiti       
     * @created 2017年11月21日 上午4:03:45  
     * @param args
     */
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("LinearRegresion").setMaster("local[*]");
        JavaSparkContext sc =  new JavaSparkContext(sparkConf);
        //原始的數據-0.4307829,-1.63735562648104 -2.00621178480549 ...
        JavaRDD<String>  data = sc.textFile("E:\\spark-ml-data\\lpsa.txt");

        //轉換數據格式:把每一行原始的數據(num1,num2 num3 ...)轉換成LabeledPoint(label, features)
        JavaRDD<LabeledPoint> parsedData = data.filter(line -> {   //過濾掉不符合的數據行
                        if(line.length() > 3)
                            return true;
                        return false;
                    }).map(line -> {   //讀取轉換成LabeledPoint
                        String[] parts = line.split(",");  //逗號分隔
                        double[] ds = Arrays.stream(parts[1].split(" "))  //空格分隔
                              .mapToDouble(Double::parseDouble)
                              .toArray();
                        return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(ds));                
                    });
        //rdd持久化內存中,后邊反復使用,不必再從磁盤加載
        parsedData.cache();

        //設置迭代次數
        int numIterations = 100;
        //三種模型進行訓練 
        LinearRegressionModel linearModel = LinearRegressionWithSGD.train(parsedData.rdd(), numIterations);
        RidgeRegressionModel ridgeModel = RidgeRegressionWithSGD.train(parsedData.rdd(), numIterations);
        LassoModel lassoModel = LassoWithSGD.train(parsedData.rdd(), numIterations);
        //打印信息
        print(parsedData, linearModel);
        print(parsedData, ridgeModel);
        print(parsedData, lassoModel);
     
        //預測一條新數據方法,8個特征值
        double[] d = new double[]{1.0, 1.0, 2.0, 1.0, 3.0, -1.0, 1.0, -2.0};
        Vector v = Vectors.dense(d);
        System.out.println("Prediction of linear: "+linearModel.predict(v));
        System.out.println("Prediction of ridge: "+ridgeModel.predict(v));
        System.out.println("Prediction of lasso: "+lassoModel.predict(v));
        

//        //保存模型
//        model.save(sc.sc(),"myModelPath" );
//        //加載模型
//        LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
        
        //關閉
        sc.close();
    }

    /**
     * 
     * @discription 統一輸出方法
     * @author liutiti       
     * @created 2017年11月22日 上午10:00:27     
     * @param parsedData
     * @param model
     */
    public static void print(JavaRDD<LabeledPoint> parsedData, GeneralizedLinearModel model) {
        JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point -> {
            double prediction = model.predict(point.features()); //用模型預測訓練數據
            return new Tuple2<>(point.label(), prediction);
        });
        //打印訓練集中的真實值與相對應的預測值
        valuesAndPreds.foreach((Tuple2<Double, Double> t) -> {
            System.out.println("訓練集真實值:"+t._1()+" ,預測值: "+t._2());
        });
        //計算預測值與實際值差值的平方值的均值
        Double MSE = valuesAndPreds.mapToDouble((Tuple2<Double, Double> t) -> Math.pow(t._1() - t._2(), 2)).mean();
        System.out.println(model.getClass().getName() + " training Mean Squared Error = " + MSE);
    }
}

 

提君博客原創

>>提君博客原創  http://www.cnblogs.com/tijun/  <<

 spark官方java api 文檔


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM