Spark機器學習(1):線性回歸算法


線性回歸算法,是利用數理統計中回歸分析,來確定兩種或兩種以上變量間相互依賴的定量關系的一種統計分析方法。

1. 梯度下降法

線性回歸可以使用最小二乘法,但是速度比較慢,因此一般使用梯度下降法(Gradient Descent),梯度下降法又分為批量梯度下降法(Batch Gradient Descent)和隨機梯度下降法(Stochastic Gradient Descent)。批量梯度下降法每次迭代需要使用訓練集里面的所有數據,當訓練集數據量較大時,速度就很慢;隨機梯度下降法每次迭代只需要一個樣本的數據,速度較快,對於大數據集,可能只需要使用少部分數據就達到收斂值,雖然有可能在最小值周圍震盪,但是大多數情況下效果不錯,所以,一般使用隨機梯度下降法。

2. Mllib的線性回歸

Mllib的線性回歸采用的是隨機梯度下降法。直接上代碼:

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

object LinearRegression {

  def main(args: Array[String]) {
    // 設置運行環境
    val conf = new SparkConf().setAppName("Linear Regression Test").setMaster("spark://master:7077").setJars(Seq("E:\\Intellij\\Projects\\MachineLearning\\MachineLearning.jar"))
    val sc = new SparkContext(conf)
    Logger.getRootLogger.setLevel(Level.WARN)

    //讀取樣本數據,生成RDD
    val data_path = "hdfs://master:9000/ml/data/lpsa.data"
    val dataRDD = sc.textFile(data_path)
    val examples = dataRDD.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()// 迭代次數
    val numIterations = 100
    // 步長
    val stepSize = 0.5
    // 選取樣本的比例
    val miniBatchFraction = 1.0
    // 用隨機梯度下降模型訓練
    val sgdModel = LinearRegressionWithSGD.train(examples, numIterations, stepSize, miniBatchFraction)

    // 對樣本進行測試
    val prediction = sgdModel.predict(examples.map(_.features))
    val predictionAndLabel = prediction.zip(examples.map(_.label))
    // 選取前100個樣本
    val show_predict = predictionAndLabel.take(100)
    println("Prediction" + "\t" + "Label" + "\t" + "Diff")
    for (i <- 0 to show_predict.length - 1) {
      val diff = show_predict(i)._1-show_predict(i)._2
      println(show_predict(i)._1 + "\t" + show_predict(i)._2 + "\t" + diff)
    }

  }

}

部分運行結果:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM