Spark Notes: Using a UDAF (User Defined Aggregate Function)


 

1. Introduction to UDAFs

First, a quick explanation of what a UDAF (User Defined Aggregate Function) is, and how an aggregate function differs from a regular function: a regular function takes one row of input and produces one output, while an aggregate function takes a group of rows (usually many) as input and produces one output, i.e. it condenses the group's values into a single value.

 

A common misconception about UDAFs

We might instinctively assume that a UDAF has to be used together with group by. In fact, a UDAF can be used with group by, but it can also be used without it. This is easy to see if you think of MySQL functions such as max and min. You can write:

select max(foo) from foobar group by bar;

which groups the rows by the bar column and returns the maximum of each group; there can be many groups, and the function is applied to each of them. You can also write:

select max(foo) from foobar;

in which case the entire table is treated as a single group and the maximum is computed over that group (i.e. the whole table). So an aggregate function really operates on a group, and it does not care how many rows that group contains.

 

2. Using UDAFs

2.1 Extending UserDefinedAggregateFunction

The general recipe for using UserDefinedAggregateFunction is:

1. Define a class that extends UserDefinedAggregateFunction and implement the method for each stage

2. Register the UDAF with Spark, binding a name to it

3. Call it in SQL statements using that name

 

Below is a UDAF example that computes an average. First, define an object that extends UserDefinedAggregateFunction:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {

  // Input schema of the aggregate function
  override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Return type of the aggregate function
  override def dataType: DataType = DoubleType

  // Whether the aggregate function is deterministic, i.e. whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Feed one new input row into the aggregation
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge two aggregation buffers
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)

}

Then register and use it:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object SparkSqlUDAFDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    spark.read.json("data/user").createOrReplaceTempView("v_user")
    spark.udf.register("u_avg", AverageUserDefinedAggregateFunction)
    // Treat the whole table as one group and compute the average age of all users
    spark.sql("select count(1) as count, u_avg(age) as avg_age from v_user").show()
    // Group by sex and compute the average age per group
    spark.sql("select sex, count(1) as count, u_avg(age) as avg_age from v_user group by sex").show()

  }

}

The dataset used (data/user, one JSON object per line):

{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}

Run results: (output screenshots omitted)

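Besides SQL, the registered object can also be called through the DataFrame API: UserDefinedAggregateFunction exposes an apply(exprs: Column*) method that wraps the UDAF in a Column, which can be passed to agg. The following is a minimal sketch of that usage, reusing the object and data path defined above (note that the whole UserDefinedAggregateFunction API is deprecated as of Spark 3.0, so this style mainly applies to Spark 2.x):

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object SparkSqlUDAFDataFrameDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    val userDf = spark.read.json("data/user")

    // Whole table as a single group
    userDf.agg(AverageUserDefinedAggregateFunction(col("age")).as("avg_age")).show()

    // One group per sex
    userDf.groupBy("sex")
      .agg(AverageUserDefinedAggregateFunction(col("age")).as("avg_age"))
      .show()

  }

}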

 

2.2 Extending Aggregator

The other approach is to extend the Aggregator class; its advantage is that it is strongly typed:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

/**
  * Computes an average.
  */
object AverageAggregator extends Aggregator[User, Average, Double] {

  // Initial value of the buffer
  override def zero: Average = Average(0L, 0L)

  // Process one new record
  override def reduce(b: Average, a: User): Average = {
    b.sum += a.age
    b.count += 1L
    b
  }

  // Merge two aggregation buffers
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count

  // Encoder for the intermediate buffer type, used when shuffling partial aggregates
  override def bufferEncoder: Encoder[Average] = Encoders.product

  // Encoder for the final output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}

/**
  * Buffer used while computing the average.
  *
  * @param sum   running sum of the ages seen so far
  * @param count number of records seen so far
  */
case class Average(var sum: Long, var count: Long)

case class User(id: Long, name: String, sex: String, age: Long)

Invocation:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object AverageAggregatorDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    import spark.implicits._
    val user = spark.read.json("data/user").as[User]
    user.select(AverageAggregator.toColumn.name("avg")).show()

  }

}

Run result: (output screenshot omitted)

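An Aggregator can also be applied per group on a typed Dataset via groupByKey, which gives the same kind of grouped aggregation as the group by query in section 2.1. Below is a minimal sketch under that assumption, reusing the User and AverageAggregator definitions above:

package cc11001100.spark.sql.udaf

import org.apache.spark.sql.SparkSession

object AverageAggregatorGroupByDemo_001 {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
    import spark.implicits._
    val user = spark.read.json("data/user").as[User]

    // Group the typed Dataset by sex, then apply the typed aggregator to each group
    user.groupByKey(_.sex)
      .agg(AverageAggregator.toColumn.name("avg_age"))
      .show()

  }

}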

 
