There are two ways to write a user-defined aggregate function, distinguished by whether strong typing is used. Full demo: https://github.com/asker124143222/spark-demo
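Both examples below read input/userinfo.json. Spark's JSON reader expects one object per line (JSON Lines). The file itself is not part of this post; a hypothetical sample matching the fields the code uses might be:

{ "name": "tom", "age": 20 }
{ "name": "jack", "age": 30 }
{ "name": "rose" }

With this data both programs would print 25.0, because rows with a null age are skipped by the update/reduce logic.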
1. Without strong typing: extend UserDefinedAggregateFunction
package com.home.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object Ex_sparkUDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setAppName("spark udf").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Create an instance of the custom aggregate function
    val myUdaf = new MyAgeAvgFunc
    // Register the custom function for use in SQL
    spark.udf.register("ageAvg", myUdaf)

    // Use the aggregate function
    val frame: DataFrame = spark.read.json("input/userinfo.json")
    frame.createOrReplaceTempView("userinfo")
    spark.sql("select ageAvg(age) from userinfo").show()

    spark.stop()
  }
}

// Declare the custom function
// It averages the age over records such as: { "name": "tom", "age" : 20 }
class MyAgeAvgFunc extends UserDefinedAggregateFunction {
  // Schema of the function's input; in this example the only input is the age
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer used during the computation.
  // To compute the average age we need two fields: the age total (sum) and the number of ages (count)
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Return type of the function
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic (same input always yields the same output)
  override def deterministic: Boolean = true

  // Initialize the buffer before the computation; it is indexed like an array,
  // in the same field order as declared in bufferSchema
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // sum
    buffer(0) = 0L
    // count
    buffer(1) = 0L
  }

  // Update the buffer with each incoming row; input matches inputSchema.
  // In this example each row carries a single value, the age
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    // sum
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    // count: add 1 for each row seen
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge the buffers from multiple nodes (Spark is distributed)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result; here it is sum / count
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
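Besides registering the UDAF for SQL, the instance can be applied directly to a DataFrame: UserDefinedAggregateFunction has an apply method that takes Columns and returns a Column. A minimal sketch, reusing frame and myUdaf from inside main above:

import org.apache.spark.sql.functions.col

// Call the UDAF instance as a column expression; no SQL registration required
frame.agg(myUdaf(col("age")).as("avgAge")).show()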
2. With strong typing: extend Aggregator
package com.home.spark

import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Aggregator

object Ex_sparkUDAF2 {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(true).setAppName("spark udf class").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Converting an RDD to a DataFrame or Dataset requires the SparkSession instance's implicit conversions.
    // Import them here; note that "spark" is the SparkSession object name, not a package name
    import spark.implicits._

    // Create the aggregate function instance and turn it into a typed column
    val myAvgFunc = new MyAgeAvgClassFunc
    val avgCol: TypedColumn[UserBean, Double] = myAvgFunc.toColumn.name("avgAge")

    val frame = spark.read.json("input/userinfo.json")
    val userDS: Dataset[UserBean] = frame.as[UserBean]

    // Apply the function
    userDS.select(avgCol).show()

    spark.stop()
  }
}

case class UserBean(name: String, age: BigInt)

case class AvgBuffer(var sum: BigInt, var count: Int)

// Declare the user-defined function (strongly typed):
// extend Aggregator with the input, buffer, and output type parameters,
// then implement its methods
class MyAgeAvgClassFunc extends Aggregator[UserBean, AvgBuffer, Double] {
  // Initialize the buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  // Aggregate one input record into the buffer
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    if (a.age == null) return b
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // Merge two buffers (from different partitions)
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
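Note that UserDefinedAggregateFunction is deprecated since Spark 3.0; the recommended replacement is to wrap a typed Aggregator with functions.udaf so it can also be registered for SQL. A minimal sketch assuming Spark 3.x; SumCount, MyAvgAgg, and Ex_sparkUDAF3 are illustrative names, not part of the demo repo:

package com.home.spark

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

case class SumCount(var sum: Long, var count: Long)

// A typed Aggregator over a plain Long column, so it can back a SQL function
object MyAvgAgg extends Aggregator[Long, SumCount, Double] {
  override def zero: SumCount = SumCount(0L, 0L)
  override def reduce(b: SumCount, a: Long): SumCount = { b.sum += a; b.count += 1; b }
  override def merge(b1: SumCount, b2: SumCount): SumCount = {
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  override def finish(r: SumCount): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[SumCount] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

object Ex_sparkUDAF3 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("spark udaf v3").master("local[*]").getOrCreate()
    val frame = spark.read.json("input/userinfo.json")
    frame.createOrReplaceTempView("userinfo")
    // functions.udaf (Spark 3.0+) adapts a typed Aggregator for untyped/SQL use
    spark.udf.register("ageAvg", functions.udaf(MyAvgAgg, Encoders.scalaLong))
    spark.sql("select ageAvg(age) from userinfo").show()
    spark.stop()
  }
}

One caveat: with a primitive Long input type, null handling may differ from the hand-written update/reduce checks in the two examples above.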