spark編寫UDF和UDAF


UDF:

一、編寫udf類,在其中定義udf函數

package spark._sql.UDF

import org.apache.spark.sql.functions._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-9:41
  **/
object udfs {
  def len(str: String): Int = str.length

  def ageThan(age: Int, small: Int): Boolean = age > small

  val ageThaner = udf((age: Int, bigger: Int) => age < bigger)
} 

二、在主方法中進行調用  

package spark._sql

import org.apache.log4j.Logger
import org.apache.spark.sql
import spark._sql.UDF.udfs._
import org.apache.spark.sql.functions._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-9:42
  **/
object UDFMain {
  val log = Logger.getLogger("UDFMain")

  def main(args: Array[String]): Unit = {
    val ssc = new sql.SparkSession.Builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    ssc.sparkContext.setLogLevel("warn")

    val df = ssc.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2))).toDF("age", "class_id")
    df.createOrReplaceTempView("table")

    ssc.udf.register("len", len _)
    ssc.sql("select age,len(age) as len from table").show(20, false)
    println("=====================================")
    ssc.udf.register("ageThan", ageThan _)
    ssc.sql("select age from table where ageThan(age,15)").show()
    println("=====================================")
    import ssc.implicits._
    val r = ssc.sql("select * from table")
    r.filter(ageThaner($"age", lit(20))).show()
    println("=====================================")

    ssc.stop()
  }
}

  運行結果:

  

  可以看到,以上代碼中一共定義了三個不同的udf函數,分別對三個函數進行說明:

  • len(str: String):該函數使用用來獲取傳入字段的長度,str 即為所需要傳入的字段
    •   在使用的時候,需要現將其進行注冊並賦予其函數名:ssc.udf.register("len", len _),調用的時候直接在sql語句中通過函數名來進行調用
  • ageThan(age: Int, small: Int):該函數式用來比較傳入的age與已有的small大小,返回一個boolean值,該函數需要是用在where條件語句中用來進行過濾使用
    •     在使用的時候,需要現將其進行注冊並賦予其函數名:ssc.udf.register("ageThan", ageThan _),調用的時候直接在sql語句中通過函數名來進行調用
  • ageThaner:該函數跟上面兩個不同,所謂的不同指的是:
    •   定義方式不同:通過使用org.apache.spark.sql.functions._ 中的udf函數在定義的時候就將其注冊好
    •        使用場景不同:使用在dataframe中,用來進行select,filter操作中
    •        對於該函數的第二列來說,如果是常量的話,需要使用org.apache.spark.sql.function._ 中的lit進行包裝,不能將常量直接傳入,否則,程序不認識該常量會報錯,如果是列名的話,則沒問題,使用($"colName")方式即可。

UDAF(弱類型):

  UDAF相對於udf來說稍微麻煩一下,且需要完全理解當中每個函數的含義才可以輕而易舉的寫出符合自己預期的UDAF函數,      

     UDAF需要繼承 UserDefinedAggregateFunction ,並且復寫當中的方法

方法含義說明:

def inputSchema: StructType =

    StructType(Array(StructField("value", IntegerType)))

  inputSchema用來定義,輸入的字段的類型,字段名可以隨便定義,這里定義為value,也可以是其他的,不重要,關鍵是字段類型一定要與所要傳入計算的字段進行對應,且必須使用org.apche.spark.sql.type. _ 中的類型

def bufferSchema: StructType = StructType(Array(

    StructField("count", IntegerType), StructField("ages", DoubleType)))

  bufferSchema用來定義生成中間數據的結果類型,例如在求和的時候,要求a+b+c,相加順序為a+b=ab,ab+c=abc ,ab即為中間結果。

def dataType: DataType = DoubleType

  dataType為函數返回值的類型,例子中,該UDAF最終返回的結果為double類型,這里的類型不能寫成double,要寫成org.apache.spark.sql.type._支持的類型DoubleType.

 def deterministic: Boolean = true

  daterministic 為代表結果是否為確定性的,也就是說,相同的輸入是否有相同的輸出。

def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.0
  }

  initalize 初始化中間結果,即count和ages的初始值。

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1 //更新計數器
    buffer(1) = buffer.getDouble(1) + input.getInt(0) //更新值
  }

  update用來更新中間結果,input為dataframe中的一行,將要合並到buffer中的數據,buffer則為已經進行合並后的中間結果。

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  merge 合並所有分片的結果,buffer2是一個分片的中間結果,buffer1是整個合並過程中的結果。

def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getInt(0)
  }

  evaluate 函數式真正進行計算的函數,計算返回函數的結果,buffer是merge合並后的結果

 

案例需求:求分組中age的平均數

  先上代碼:

一、定義UDAF函數

package spark._sql.UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * AUTHOR Guozy
  * DATE   2019/7/18-14:47
  **/
class udafs() extends UserDefinedAggregateFunction {

  def inputSchema: StructType =

    StructType(Array(StructField("value", IntegerType)))

  def bufferSchema: StructType = StructType(Array(

    StructField("count", IntegerType), StructField("ages", DoubleType)))

  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.0
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1 //更新計數器
    buffer(1) = buffer.getDouble(1) + input.getInt(0) //更新值
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  def evaluate(buffer: Row): Any = {
    buffer.getDouble(1) / buffer.getInt(0)
  }
}

二、主函數引用:

package spark._sql.UDF

import org.apache.spark.sql
import org.apache.spark.sql.functions._
import spark._sql.UDAF.udafs

/**
  * AUTHOR Guozy
  * DATE   2019/7/19-16:04
  **/
object UDAFMain {
  def main(args: Array[String]): Unit = {
    val ssc = new sql.SparkSession.Builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    ssc.sparkContext.setLogLevel("warn")

    val ageDF = ssc.createDataFrame(Seq((22, 1), (24, 1), (11, 2), (15, 2))).toDF("age", "class_id")
    ssc.udf.register("avgage", new udafs)
    ageDF.createOrReplaceTempView("table")
    ssc.sql("select avgage(age) from table group by class_id").show()

    ssc.stop()
  }
}

 運行結果:

  

UDAF(強類型)

  關於UDAF函數,一種是關於上面所描述的弱類型聚合函數,弱類型聚合函數只能是在sql數據中進行使用,在使用的過程中,對於傳入的值的類型,如果有問題,只有在程序運行的時候才能進行發現。這樣的話,靈活性不是很高。如果能夠在編譯的時候就對傳入的類型進行限定,並且輸入類型以及輸出類型都是可以有我們自己定義,這樣的相對來說就靈活許多了,而且在生產中使用的也是比較多的。這就是接下來要說的強類型的UDAF,但是有一點需要注意的是,強類型的UDAF不能在sql語句中使用,只能在DLS語句中使用

  自定義強類型的UDFA需要繼承 Aggregator 這個類,與弱類型聚合函數有點區別

  接下來看一下該類中有哪些方法:

  

abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

  /**
   * A zero value for this aggregation. Should satisfy the property that any b + zero = b.
   * @since 1.6.0
   * 初始化緩沖區中的對象
   */
  def zero: BUF

  /**
   * Combine two values to produce a new value.  For performance, the function may modify `b` and
   * return it instead of constructing new object for b.
   * 更新緩沖區中的數據
   * @since 1.6.0
   */
  def reduce(b: BUF, a: IN): BUF

  /**
   * Merge two intermediate values.
   * 合並緩沖區
   * @since 1.6.0
   */
  def merge(b1: BUF, b2: BUF): BUF

  /**
   * Transform the output of the reduction.
   * 實現真正的計算
   * @since 1.6.0
   */
  def finish(reduction: BUF): OUT

  /**
   * Specifies the `Encoder` for the intermediate value type.
   * 緩沖區編碼方式,如果是自定義類型,就是用 Encoders.product
   * @since 2.0.0
   */
  def bufferEncoder: Encoder[BUF]

  /**
   * Specifies the `Encoder` for the final output value type.
   * 最終結果的編碼方式,如果是原生的類型,就用原生的類型,比如Encoders.scalaDouble,等等
   * @since 2.0.0
   */
  def outputEncoder: Encoder[OUT]

  /**
   * Returns this `Aggregator` as a `TypedColumn` that can be used in `Dataset`.
   * operations.
   * @since 1.6.0
   */
  def toColumn: TypedColumn[IN, OUT] = {
    implicit val bEncoder = bufferEncoder
    implicit val cEncoder = outputEncoder

    val expr =
      AggregateExpression(
        TypedAggregateExpression(this),
        Complete,
        isDistinct = false)

    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
  }
}

編碼實現:

  需求:實現求平均年齡,並返回一個自定義的類型

  步驟:

  一、自定義UDAF類

import org.apache.spark.sql.expressions.Aggregator

case class user(name: String, age: Int)

case class avgAggBuffer(var sum: Long, var count: Int)

// 自定義聚合函數,強類型,這里的返回值,我們可以自定義,不一定非要是Double,也可以是自定義封裝類型
class aggerageUDAF extends Aggregator[user, avgAggBuffer, Double] {

  import org.apache.spark.sql.{Encoder, Encoders}

  // 初始化緩沖區的對象
  override def zero: avgAggBuffer = {
    avgAggBuffer(0, 0)
  }

  // 更新緩沖區的數據
  override def reduce(b: avgAggBuffer, a: user): avgAggBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // 合並不同的緩沖區
  override def merge(b1: avgAggBuffer, b2: avgAggBuffer): avgAggBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // 完成計算
  override def finish(reduction: avgAggBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // 如果是自定義的類型則用該方式進行轉碼,基本固定
  override def bufferEncoder: Encoder[avgAggBuffer] = Encoders.product

  // 轉碼,如果是原生類型,則直接進行轉碼
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

  二、主類調用

object customeUDAF {
  def main(args: Array[String]): Unit = {

    val ssc = SparkSession
      .builder()
      .master("local[2]")
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()
    val sc = ssc.sparkContext
    sc.setLogLevel("error")

    import org.apache.spark.sql.{Dataset, TypedColumn}
    import ssc.implicits._
    val rdd1 = sc.parallelize(List(("xl",20),("xh",30),("xw",40))).toDF("name","age")

    val dataset: Dataset[user] = rdd1.as[user]

    // 注冊聚合函數
    val aggUDF = new aggerageUDAF
    // 將聚合函數轉換為查詢列
    val cols: TypedColumn[user, Double] = aggUDF.toColumn.name("avgAge")
    // 只能通過DSL語句進行使用
    dataset.select(cols).show()

    ssc.stop()
  }
}

 運行結果

  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM