機器學習——ALS算法


ALS算法中文名又稱為最小二乘法,在機器學習中,ALS特指使用最小二乘法求解的協同過濾算法中的一種

ALS算法在構建spark推薦系統時,是用的最多的協同過濾算法,集成到了spark中ml庫和mllib庫中(ml庫算法接口基於DataFrames,mllib庫算法接口基於RDDs,ml庫使用越來越普遍)

ALS算法屬於User-Item CF,同時會考慮User和Item兩個方面,是一種同時考慮到用戶和物品的算法

 

  找出基於UXI的“用戶-物品”矩陣如圖:

  找到和“用戶-物品”近似的K維低階矩陣(K值為ALS中的超參,通常范圍取10-200):用戶矩陣->U x K,物品矩陣->I x K,這兩個因子矩陣的乘積,得到的則為原始評級數據的近似值:

原理分析

  ALS實現原理是迭代求解一系列的最小損失值,在每次迭代時,需要固定因子矩陣中的一個,來更新另一個矩陣因子矩陣,之后將更新的矩陣固定住,再更新另一個矩陣,直到模型收斂

記在原始評分矩陣中的用戶Ut和對項目Is的打分Rst,由乘積(VTU)擬合獲得的評分為(VTU)st.則兩者平方誤差為((VTU)st-Rst)2,則經驗誤差可以記為:

  該模型對於每一個用戶特征向量和項目特征向量都是凸的,意味着可以在它所有的U和I 上能達到局部最優,等價於在矩陣U,I,R中各列向量都獨立服從各自的正太分布下的極大似然擬合,所以該模型也被稱為概率矩陣分解模型(PMF)

 

測試:

import com.nj.untils.MySqlHandler
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object AlsTest {
  //自定義打分標准
  val actToNum=udf{
    (info:String)=>
      info match  {
        case "BROWSE"=>1
        case "COLLECT"=>2
        case "BUYCAR"=>4
        case "ORDERS"=>8
      }
  }

  case class UserAction(act:String,act_time:String,cust_id:String,good_id:String,browse:String)

  def main(args: Array[String]): Unit = {
    val spark=SparkSession.builder().appName("job").master("local[*]").getOrCreate()
   //讀取用戶的操作行為,並讀取
    val dff = spark.sparkContext.textFile("file:///C:/Users/Administrator/Desktop/myact/*.log").cache()
    //將讀入的數據轉為dataframe 計算出每個用戶對該用戶接觸過的商品的評分
    import spark.implicits._
    val df = dff.map(line=>{
      val arr=line.split(" ")
      UserAction(arr(0),arr(1),arr(2),arr(3),arr(4))
    }).toDF().select($"cust_id",$"good_id",actToNum($"act").alias("score"))
        .groupBy("cust_id","good_id").agg(sum($"score").alias("score")).cache()
    //為了放置用戶編號或商品編號中含有非數字情況,所以對所有的商品和用戶編號給一個連續的對應的數字編號后再存到緩存

    val gwnd=Window.partitionBy().orderBy("good_id")
    val cwnd=Window.partitionBy().orderBy("cust_id")
    val goodstab=MySqlHandler.readMySQL(spark,"goods")
        .select($"good_id",row_number().over(gwnd).alias("gid") ).cache()
    val custtab=MySqlHandler.readMySQL(spark,"customs")
        .select($"cust_id",row_number().over(cwnd).alias("uid")).cache()
    //將df 和goodstab以及custtab join 只保留(gid,uid,score)
    val zc = df.join(goodstab,Seq("good_id"),"inner").join(custtab,Seq("cust_id"),"inner")
        .select("gid","uid","score")

    val alldata=zc.rdd.map(row=>{
      Rating(row.getAs("uid").toString.toInt,
        row.getAs("gid").toString.toInt,
        row.getAs("score").toString.toFloat)
    })
   //查看user為200的用戶的所有評分
    val a=alldata.keyBy(_.user).lookup(200)
    //println(a.size)

    //將獲得的Rating集合拆分按照0.2,0.8比例拆分為兩個集合
    val Array(train,test)=alldata.randomSplit(Array(0.8,0.2))  //8成訓練模型 2成測試模型
    //使用8成的數據去訓練模型
//    val model = ALS.train(train,rank = 10,maxIter = 20,implicitPrefs = false)  ml 適合DataFrame


    val model = new ALS().setRank(10).setIterations(10).setLambda(0.01).setImplicitPrefs(false).run(alldata) //mllib 適合RDD算子
    val tj=model.recommendProductsForUsers(10)  //每一個user都拿出打分最高的前10位 得到RDD(Int,Array[Rating])

    tj.flatMap{
      case(user:Int,ratings:Array[Rating])=>
        ratings.map{case (rat:Rating)=>(user,rat.product,rat.rating)}
    }.foreach(println)
    //可以選擇存儲到hdfs
    tj.toDF().write.mode("overwrite").save("path")
    spark.stop()

  }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM