sparksql 自定義用戶函數(UDF)


自定義用戶函數有兩種方式,區別:是否使用強類型,參考demo:https://github.com/asker124143222/spark-demo

1、不使用強類型,繼承UserDefinedAggregateFunction

package com.home.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


object Ex_sparkUDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setAppName("spark udf").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()


    //自定義聚合函數
    //創建聚合函數對象
    val myUdaf = new MyAgeAvgFunc

    //注冊自定義函數
    spark.udf.register("ageAvg",myUdaf)

    //使用聚合函數
    val frame: DataFrame = spark.read.json("input/userinfo.json")
    frame.createOrReplaceTempView("userinfo")
    spark.sql("select ageAvg(age) from userinfo").show()

    spark.stop()
  }
}

//聲明自定義函數
//實現對年齡的平均,數據如:{ "name": "tom", "age" : 20}
class MyAgeAvgFunc extends UserDefinedAggregateFunction {
  //函數輸入的數據結構,本例中只有年齡是輸入數據
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  //計算時的數據結構(緩沖區)
  // 本例中有要計算年齡平均值,必須有兩個計算結構,一個是年齡總計(sum),一個是年齡個數(count)
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  //函數返回的數據類型
  override def dataType: DataType = DoubleType

  //函數是否穩定
  override def deterministic: Boolean = true

  //計算前緩沖區的初始化,結構類似數組,這里緩沖區與之前定義的bufferSchema順序一致
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //sum
    buffer(0) = 0L
    //count
    buffer(1) = 0L
  }

  //根據查詢結果更新緩沖區數據,input是每次進入的數據,其數據結構與之前定義的inputSchema相同
  //本例中每次輸入的數據只有一個就是年齡
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(input.isNullAt(0)) return
    //sum
    buffer(0) = buffer.getLong(0) + input.getLong(0)

    //count,每次來一個數據加1
    buffer(1) = buffer.getLong(1) + 1
  }

  //將多個節點的緩沖區合並到一起(因為spark是分布式的)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)

    //count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //計算最終結果,本例中就是(sum / count)
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}

2、使用強類型,

package com.home.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator


object Ex_sparkUDAF2 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setAppName("spark udf class").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    //rdd轉換成df或者ds需要SparkSession實例的隱式轉換
    //導入隱式轉換,注意這里的spark不是包名,而是SparkSession的對象名
    import spark.implicits._

    //創建聚合函數對象
    val myAvgFunc = new MyAgeAvgClassFunc
    val avgCol: TypedColumn[UserBean, Double] = myAvgFunc.toColumn.name("avgAge")
    val frame = spark.read.json("input/userinfo.json")
    val userDS: Dataset[UserBean] = frame.as[UserBean]
    //應用函數
    userDS.select(avgCol).show()

    spark.stop()
  }
}


case class UserBean(name: String, age: BigInt)

case class AvgBuffer(var sum: BigInt, var count: Int)

//聲明用戶自定義函數(強類型方式)
//繼承Aggregator,設定泛型
//實現方法
class MyAgeAvgClassFunc extends Aggregator[UserBean, AvgBuffer, Double] {
  //初始化緩沖區
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  //聚合數據
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    if(a.age == null) return b
    b.sum = b.sum + a.age
    b.count = b.count + 1

    b
  }

  //緩沖區合並操作
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count

    b1
  }

  //完成計算
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

 

繼承Aggregator


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM