一、UDAF簡介
先解釋一下什么是UDAF(User Defined Aggregate Function),即用戶定義的聚合函數,聚合函數和普通函數的區別是什么呢,普通函數是接受一行輸入產生一個輸出,聚合函數是接受一組(一般是多行)輸入然后產生一個輸出,即將一組的值想辦法聚合一下。
關於UDAF的一個誤區
我們可能下意識的認為UDAF是需要和group by一起使用的,實際上UDAF可以跟group by一起使用,也可以不跟group by一起使用,這個其實比較好理解,聯想到mysql中的max、min等函數,可以:
select max(foo) from foobar group by bar;
表示根據bar字段分組,然后求每個分組的最大值,這時候的分組有很多個,使用這個函數對每個分組進行處理,也可以:
select max(foo) from foobar;
這種情況可以將整張表看做是一個分組,然后在這個分組(實際上就是一整張表)中求最大值。所以聚合函數實際上是對分組做處理,而不關心分組中記錄的具體數量。
二、UDAF使用
2.1 繼承UserDefinedAggregateFunction
使用UserDefinedAggregateFunction的套路:
1. 自定義類繼承UserDefinedAggregateFunction,對每個階段方法做實現
2. 在spark中注冊UDAF,為其綁定一個名字
3. 然后就可以在sql語句中使用上面綁定的名字調用
下面寫一個計算平均值的UDAF例子,首先定義一個類繼承UserDefinedAggregateFunction:
package cc11001100.spark.sql.udaf import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction { // 聚合函數的輸入數據結構 override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil) // 緩存區數據結構 override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) // 聚合函數返回值數據結構 override def dataType: DataType = DoubleType // 聚合函數是否是冪等的,即相同輸入是否總是能得到相同輸出 override def deterministic: Boolean = true // 初始化緩沖區 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0L buffer(1) = 0L } // 給聚合函數傳入一條新數據進行處理 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { if (input.isNullAt(0)) return buffer(0) = buffer.getLong(0) + input.getLong(0) buffer(1) = buffer.getLong(1) + 1 } // 合並聚合函數緩沖區 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) } // 計算最終結果 override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1) }
然后注冊並使用它:
package cc11001100.spark.sql.udaf import org.apache.spark.sql.SparkSession object SparkSqlUDAFDemo_001 { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate() spark.read.json("data/user").createOrReplaceTempView("v_user") spark.udf.register("u_avg", AverageUserDefinedAggregateFunction) // 將整張表看做是一個分組對求所有人的平均年齡 spark.sql("select count(1) as count, u_avg(age) as avg_age from v_user").show() // 按照性別分組求平均年齡 spark.sql("select sex, count(1) as count, u_avg(age) as avg_age from v_user group by sex").show() } }
使用到的數據集:
{"id": 1001, "name": "foo", "sex": "man", "age": 20} {"id": 1002, "name": "bar", "sex": "man", "age": 24} {"id": 1003, "name": "baz", "sex": "man", "age": 18} {"id": 1004, "name": "foo1", "sex": "woman", "age": 17} {"id": 1005, "name": "bar2", "sex": "woman", "age": 19} {"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
運行結果:
2.2 繼承Aggregator
還有另一種方式就是繼承Aggregator這個類,優點是可以帶類型:
package cc11001100.spark.sql.udaf import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.{Encoder, Encoders} /** * 計算平均值 * */ object AverageAggregator extends Aggregator[User, Average, Double] { // 初始化buffer override def zero: Average = Average(0L, 0L) // 處理一條新的記錄 override def reduce(b: Average, a: User): Average = { b.sum += a.age b.count += 1L b } // 合並聚合buffer override def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum b1.count += b2.count b1 } // 減少中間數據傳輸 override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count override def bufferEncoder: Encoder[Average] = Encoders.product // 最終輸出結果的類型 override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } /** * 計算平均值過程中使用的Buffer * * @param sum * @param count */ case class Average(var sum: Long, var count: Long) { } case class User(id: Long, name: String, sex: String, age: Long) { }
調用:
package cc11001100.spark.sql.udaf import org.apache.spark.sql.SparkSession object AverageAggregatorDemo_001 { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate() import spark.implicits._ val user = spark.read.json("data/user").as[User] user.select(AverageAggregator.toColumn.name("avg")).show() } }
運行結果:
.