UDAF(用戶自定義聚合函數)求眾數


除了逐行處理數據的udf,還有比較常見的就是聚合多行處理udaf,自定義聚合函數。類比rdd編程就是map和reduce算子的區別。
自定義UDAF,需要extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction,並實現接口中的8個方法。
udaf寫起來比較麻煩,我下面列一個之前寫的取眾數聚合函數,在我們通常在聚合統計的時候可能會受某條臟數據的影響。
舉個栗子:
對於一個app日志聚合的時候,有id與ip,原則上一個id有一個ip,但是在多條數據里有一條ip是錯誤的或者為空的,這時候group能會聚合成兩條數據了就,如果使用max,min對ip也進行聚合,那也不太合理,這時候可以進行投票,去類似多數對結果,從而聚合后只有一個設備。
廢話少說,上代碼:
import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ /** * Description: 自定義聚合函數:眾數(取列內頻率最高的一條) */

class UDAFGetMode extends UserDefinedAggregateFunction { override def inputSchema: StructType = { StructType(StructField("inputStr", StringType, true) :: Nil) } override def bufferSchema: StructType = { StructType(StructField("bufferMap", MapType(keyType = StringType, valueType = IntegerType), true) :: Nil) } override def dataType: DataType = StringType override def deterministic: Boolean = false

  //初始化map
  override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = scala.collection.immutable.Map[String, Int]() } //如果包含這個key則value+1,否則寫入key,value=1
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { val key = input.getAs[String](0) val immap = buffer.getAs[Map[String, Int]](0) val bufferMap = scala.collection.mutable.Map[String, Int](immap.toSeq: _*) val ret = if (bufferMap.contains(key)) { // val new_value = bufferMap.get(key).get + 1
      val new_value = bufferMap(key) + 1 bufferMap.put(key, new_value) bufferMap } else { bufferMap.put(key, 1) bufferMap } buffer.update(0, ret) } override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { //合並兩個map 相同的key的value累加
    val tempMap = (buffer1.getAs[Map[String, Int]](0) /: buffer2.getAs[Map[String, Int]](0)) { case (map, (k, v)) => map + (k -> (v + map.getOrElse(k, 0))) } buffer1.update(0, tempMap) } override def evaluate(buffer: Row): Any = { //返回值最大的key
    var max_value = 0
    var max_key = "" buffer.getAs[Map[String, Int]](0).foreach({ x => val key = x._1 val value = x._2 if (value > max_value) { max_value = value max_key = key } }) max_key } }

測試類:

object UDAFTest { def main(args: Array[String]): Unit = { val spark = SparkSession.builder().master("local").appName(this.getClass.getSimpleName).getOrCreate() spark.udf.register("get_mode", new UDAFGetMode) import spark.implicits._ val df = Seq( (1, "10.10.1.1", "start"), (1, "10.10.1.1", "search"), (2, "123.123.123.1", "search"), (1, "10.10.1.0", "stop"), (2, "123.123.123.1", "start") ).toDF("id", "ip", "action") df.createTempView("tb") spark.sql(s"select id,get_mode(ip) as u_ip,count(*) as cnt from tb group by id").show() } }

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM