【翻譯】Flink Table Api & SQL — 用戶定義函數


本文翻譯自官網:User-defined Functions  https://ci.apache.org/projects/flink/flink-docs-release-1.9/dev/table/udfs.html 

Flink Table Api & SQL 翻譯目錄

用戶定義函數是一項重要功能,因為它們顯着擴展了查詢的表達能力。

注冊用戶定義的函數

在大多數情況下,必須先注冊用戶定義的函數,然后才能在查詢中使用該函數。無需注冊Scala Table API的函數。

通過調用registerFunction()方法在TableEnvironment中注冊函數。 注冊用戶定義的函數后,會將其插入TableEnvironment的函數目錄中,以便Table API或SQL解析器可以識別並正確轉換它。

請在以下子會話中找到有關如何注冊以及如何調用每種類型的用戶定義函數(ScalarFunction,TableFunction和AggregateFunction)的詳細示例。

標量函數

如果內置函數中未包含所需的標量函數,則可以為Table API和SQL定義自定義的,用戶定義的標量函數。 用戶定義的標量函數將零個,一個或多個標量值映射到新的標量值。

為了定義標量函數,必須擴展org.apache.flink.table.functions中的基類ScalarFunction並實現(一個或多個)評估方法。 標量函數的行為由評估方法確定。 評估方法必須公開聲明並命名為eval。 評估方法的參數類型和返回類型也確定標量函數的參數和返回類型。 評估方法也可以通過實現多種名為eval的方法來重載。 評估方法還可以支持可變參數,例如eval(String ... strs)。

下面的示例演示如何定義自己的哈希碼函數,如何在TableEnvironment中注冊並在查詢中調用它。 請注意,您可以在構造函數之前注冊它的標量函數:

// must be defined in static/object context
class HashCode(factor: Int) extends ScalarFunction {
  def eval(s: String): Int = {
    s.hashCode() * factor
  }
}

val tableEnv = BatchTableEnvironment.create(env)

// use the function in Scala Table API
val hashCode = new HashCode(10)
myTable.select('string, hashCode('string))

// register and use the function in SQL
tableEnv.registerFunction("hashCode", new HashCode(10))
tableEnv.sqlQuery("SELECT string, hashCode(string) FROM MyTable")

默認情況下,評估方法的結果類型由Flink的類型提取工具確定。 這對於基本類型或簡單的POJO就足夠了,但對於更復雜,自定義或復合類型可能是錯誤的。 在這些情況下,可以通過覆蓋ScalarFunction#getResultType()來手動定義結果類型的TypeInformation。

下面的示例顯示一個高級示例,該示例采用內部時間戳表示,並且還以長值形式返回內部時間戳表示。 通過重寫ScalarFunction#getResultType(),我們定義了代碼生成應將返回的long值解釋為Types.TIMESTAMP。

object TimestampModifier extends ScalarFunction {
  def eval(t: Long): Long = {
    t % 1000
  }

  override def getResultType(signature: Array[Class[_]]): TypeInformation[_] = {
    Types.TIMESTAMP
  }
}

Table Function 

與用戶定義的標量函數相似,用戶定義的表函數將零,一個或多個標量值作為輸入參數。 但是,與標量函數相比,它可以返回任意數量的行作為輸出,而不是單個值。 返回的行可能包含一列或多列。

 為了定義表函數,必須擴展org.apache.flink.table.functions中的基類TableFunction並實現(一個或多個)評估方法。 表函數的行為由其評估方法確定。 必須將評估方法聲明為公開並命名為eval。 通過實現多個名為eval的方法,可以重載TableFunction。 評估方法的參數類型確定表函數的所有有效參數。 評估方法還可以支持可變參數,例如eval(String ... strs)。 返回表的類型由TableFunction的通用類型確定。 評估方法使用受保護的collect(T)方法發出輸出行。

在Table API中,表函數與.joinLateral或.leftOuterJoinLateral一起使用。 joinLateral運算符(叉號)將外部表(運算符左側的表)中的每一行與表值函數(位於運算符的右側)產生的所有行進行連接。 leftOuterJoinLateral運算符將外部表(運算符左側的表)中的每一行與表值函數(位於運算符的右側)產生的所有行連接起來,並保留表函數返回的外部行 一個空桌子。 在SQL中,使用帶有CROSS JOIN和LEFT JOIN且帶有ON TRUE連接條件的LATERAL TABLE(<TableFunction>)(請參見下面的示例)。

下面的示例演示如何定義表值函數,如何在TableEnvironment中注冊表值函數以及如何在查詢中調用它。 請注意,可以在注冊表函數之前通過構造函數對其進行配置:

// The generic type "(String, Int)" determines the schema of the returned table as (String, Integer).
class Split(separator: String) extends TableFunction[(String, Int)] {
  def eval(str: String): Unit = {
    // use collect(...) to emit a row.
    str.split(separator).foreach(x => collect((x, x.length)))
  }
}

val tableEnv = BatchTableEnvironment.create(env)
val myTable = ...         // table schema: [a: String]

// Use the table function in the Scala Table API (Note: No registration required in Scala Table API).
val split = new Split("#")
// "as" specifies the field names of the generated table.
myTable.joinLateral(split('a) as ('word, 'length)).select('a, 'word, 'length)
myTable.leftOuterJoinLateral(split('a) as ('word, 'length)).select('a, 'word, 'length)

// Register the table function to use it in SQL queries.
tableEnv.registerFunction("split", new Split("#"))

// Use the table function in SQL with LATERAL and TABLE keywords.
// CROSS JOIN a table function (equivalent to "join" in Table API)
tableEnv.sqlQuery("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
// LEFT JOIN a table function (equivalent to "leftOuterJoin" in Table API)
tableEnv.sqlQuery("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")

重要說明:不要將TableFunction實現為Scala對象。Scala對象是單例對象,將導致並發問題。 

請注意,POJO類型沒有確定的字段順序。因此,您不能使用 AS 來重命名表函數返回的POJO字段

 默認情況下,TableFunction的結果類型由Flink的自動類型提取工具確定。 這對於基本類型和簡單的POJO非常有效,但是對於更復雜,自定義或復合類型可能是錯誤的。 在這種情況下,可以通過重寫TableFunction#getResultType()並返回其TypeInformation來手動指定結果的類型。

下面的示例顯示一個TableFunction的示例,該函數返回需要顯式類型信息的Row類型。 我們通過重寫TableFunction#getResultType()來定義返回的表類型應為RowTypeInfo(String,Integer)。

class CustomTypeSplit extends TableFunction[Row] {
  def eval(str: String): Unit = {
    str.split(" ").foreach({ s =>
      val row = new Row(2)
      row.setField(0, s)
      row.setField(1, s.length)
      collect(row)
    })
  }

  override def getResultType: TypeInformation[Row] = {
    Types.ROW(Types.STRING, Types.INT)
  }
}

聚合函數 

用戶定義的聚合函數(UDAGG)將表(具有一個或多個屬性的一個或多個行)聚合到一個標量值。

 

 

 上圖顯示了聚合的示例。 假設您有一個包含飲料數據的表。 該表由三列組成,即ID,name和price 以及5行。 假設您需要在表格中找到所有飲料中最高的price ,即執行max()匯總。 您將需要檢查5行中的每行,結果將是單個數字值。

用戶定義的聚合函數通過擴展AggregateFunction類來實現。 AggregateFunction的工作原理如下。 首先,它需要一個累加器,它是保存聚合中間結果的數據結構。 通過調用AggregateFunction的createAccumulator()方法來創建一個空的累加器。 隨后,為每個輸入行調用該函數的accumulate()方法以更新累加器。 處理完所有行后,將調用該函數的getValue()方法以計算並返回最終結果。

每種方法都必須使用以下方法AggregateFunction 

  • createAccumulator()
  • accumulate()
  • getValue()

Flink的類型提取工具可能無法識別復雜的數據類型,例如,如果它們不是基本類型或簡單的POJO。 因此,類似於ScalarFunction和TableFunction,AggregateFunction提供了一些方法來指定結果類型的TypeInformation(通過AggregateFunction#getResultType())和累加器的類型(通過AggregateFunction#getAccumulatorType())。

除上述方法外,還有一些可選擇性實現的約定方法。 盡管這些方法中的某些方法使系統可以更有效地執行查詢,但對於某些用例,其他方法是必需的。 例如,如果聚合功能應在會話組窗口的上下文中應用,則必須使用merge()方法(觀察到“連接”它們的行時,兩個會話窗口的累加器必須合並)。

AggregateFunction根據使用情況,需要以下方法 

  • retract()在有界OVER窗口上進行聚合是必需的
  • merge() 許多批處理聚合和會話窗口聚合是必需的。
  • resetAccumulator() 許多批處理聚合是必需的。

必須將AggregateFunction的所有方法聲明為public,而不是靜態的,並且必須完全按上述名稱命名。 方法createAccumulator,getValue,getResultType和getAccumulatorType在AggregateFunction抽象類中定義,而其他方法則是協定方法。 為了定義聚合函數,必須擴展基類org.apache.flink.table.functions.AggregateFunction並實現一個(或多個)累積方法。 累加的方法可以重載不同的參數類型,並支持可變參數。

下面給出了AggregateFunction的所有方法的詳細文檔。

/**
  * Base class for user-defined aggregates and table aggregates.
  *
  * @tparam T   the type of the aggregation result.
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  */
abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
    * Creates and init the Accumulator for this (table)aggregate function.
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC // MANDATORY

  /**
    * Returns the TypeInformation of the (table)aggregate function's result.
    *
    * @return The TypeInformation of the (table)aggregate function's result or null if the result
    *         type should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null // PRE-DEFINED

  /**
    * Returns the TypeInformation of the (table)aggregate function's accumulator.
    *
    * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
}

/**
  * Base class for aggregation functions. 
  *
  * @tparam T   the type of the aggregation result
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  *             AggregateFunction represents its state using accumulator, thereby the state of the
  *             AggregateFunction must be put into the accumulator.
  */
abstract class AggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

  /**
    * Processes the input values and update the provided accumulator instance. The method
    * accumulate can be overloaded with different custom types and arguments. An AggregateFunction
    * requires at least one accumulate() method.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

  /**
    * Retracts the input values from the accumulator instance. The current design assumes the
    * inputs are the values that have been previously accumulated. The method retract can be
    * overloaded with different custom types and arguments. This function must be implemented for
    * datastream bounded over aggregate.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

  /**
    * Merges a group of accumulator instances into one accumulator instance. This function must be
    * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which will keep the merged aggregate results. It should
    *                     be noted that the accumulator may contain the previous aggregated
    *                     results. Therefore user should not replace or clean this instance in the
    *                     custom merge method.
    * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
    *                     merged.
    */
  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL
  
  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T // MANDATORY

  /**
    * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for
    * dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which needs to be reset
    */
  def resetAccumulator(accumulator: ACC): Unit // OPTIONAL

  /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false // PRE-DEFINED
}

以下示例顯示了怎么使用

  • 定義一個AggregateFunction計算給定列上的加權平均值
  • TableEnvironment注冊函數
  • 在查詢中使用該函數。

為了計算加權平均值,累加器需要存儲所有累加數據的加權和和計數。 在我們的示例中,我們將一個WeightedAvgAccum類定義為累加器。 累加器由Flink的檢查點機制自動備份,並在無法確保一次准確語義的情況下恢復。

我們的WeightedAvg AggregateFunction的accumulate()方法具有三個輸入。 第一個是WeightedAvgAccum累加器,其他兩個是用戶定義的輸入:輸入值ivalue和輸入iweight的權重。 盡管大多數聚合類型都不強制使用retract(),merge()和resetAccumulator()方法,但我們在下面提供了它們作為示例。 請注意,我們在Scala示例中使用了Java基本類型並定義了getResultType()和getAccumulatorType()方法,因為Flink類型提取不適用於Scala類型。

import java.lang.{Long => JLong, Integer => JInteger}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.AggregateFunction

/**
 * Accumulator for WeightedAvg.
 */
class WeightedAvgAccum extends JTuple1[JLong, JInteger] {
  sum = 0L
  count = 0
}

/**
 * Weighted Average user-defined aggregate function.
 */
class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] {

  override def createAccumulator(): WeightedAvgAccum = {
    new WeightedAvgAccum
  }
  
  override def getValue(acc: WeightedAvgAccum): JLong = {
    if (acc.count == 0) {
        null
    } else {
        acc.sum / acc.count
    }
  }
  
  def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum += iValue * iWeight
    acc.count += iWeight
  }

  def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum -= iValue * iWeight
    acc.count -= iWeight
  }
    
  def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
    val iter = it.iterator()
    while (iter.hasNext) {
      val a = iter.next()
      acc.count += a.count
      acc.sum += a.sum
    }
  }

  def resetAccumulator(acc: WeightedAvgAccum): Unit = {
    acc.count = 0
    acc.sum = 0L
  }

  override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {
    new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)
  }

  override def getResultType: TypeInformation[JLong] = Types.LONG
}

// register function
val tEnv: StreamTableEnvironment = ???
tEnv.registerFunction("wAvg", new WeightedAvg())

// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")

表聚合函數

用戶定義的表聚合函數(UDTAGG)將一個表(具有一個或多個屬性的一個或多個行)聚合到具有多行和多列的結果表中。

 

 

 上圖顯示了表聚合的示例。 假設您有一個包含飲料數據的表。 該表由三列組成,即ID,name 和 price 以及5行。 假設您需要在表格中找到所有飲料中 price 最高的前2個,即執行top2()表匯總。 您將需要檢查5行中的每行,結果將是帶有前2個值的表。

用戶定義的表聚合功能通過擴展TableAggregateFunction類來實現。 TableAggregateFunction的工作原理如下。 首先,它需要一個累加器,它是保存聚合中間結果的數據結構。 通過調用TableAggregateFunction的createAccumulator()方法來創建一個空的累加器。 隨后,為每個輸入行調用該函數的accumulate()方法以更新累加器。 處理完所有行后,將調用該函數的emitValue()方法來計算並返回最終結果。

每種方法都必須使用以下方法TableAggregateFunction

  • createAccumulator()
  • accumulate()

Flink的類型提取工具可能無法識別復雜的數據類型,例如,如果它們不是基本類型或簡單的POJO。 因此,類似於ScalarFunction和TableFunction,TableAggregateFunction提供了一些方法來指定結果類型的TypeInformation(通過TableAggregateFunction#getResultType())和累加器的類型(通過TableAggregateFunction#getAccumulatorType())。

除上述方法外,還有一些可選擇性實現的約定方法。 盡管這些方法中的某些方法使系統可以更有效地執行查詢,但對於某些用例,其他方法是必需的。 例如,如果聚合功能應在會話組窗口的上下文中應用,則必須使用merge()方法(觀察到“連接”它們的行時,兩個會話窗口的累加器必須合並)。

TableAggregateFunction根據使用情況,需要以下方法

  • retract()在有界OVER窗口上進行聚合是必需的
  • merge() 許多批處理聚合和會話窗口聚合是必需的。
  • resetAccumulator() 許多批處理聚合是必需的。
  • emitValue() 是批處理和窗口聚合所必需的。

TableAggregateFunction的以下方法用於提高流作業的性能:

  • emitUpdateWithRetract() 用於發出在撤回模式下已更新的值。

對於emitValue方法,它根據累加器發出完整的數據。 以TopN為例,emitValue每次都會發出所有前n個值。 這可能會給流作業帶來性能問題。 為了提高性能,用戶還可以實現emmitUpdateWithRetract方法來提高性能。 該方法以縮回模式增量輸出數據,即,一旦有更新,我們必須先縮回舊記錄,然后再發送新的更新記錄。 如果所有方法都在表聚合函數中定義,則該方法將優先於emitValue方法使用,因為emitUpdateWithRetract被認為比emitValue更有效,因為它可以增量輸出值。

必須將TableAggregateFunction的所有方法聲明為public,而不是靜態的,並且其命名必須與上述名稱完全相同。 方法createAccumulator,getResultType和getAccumulatorType在TableAggregateFunction的父抽象類中定義,而其他方法則是契約方法。 為了定義表聚合函數,必須擴展基類org.apache.flink.table.functions.TableAggregateFunction並實現一個(或多個)累積方法。 累加的方法可以重載不同的參數類型,並支持可變參數。

下面給出了TableAggregateFunction的所有方法的詳細文檔。

/**
  * Base class for user-defined aggregates and table aggregates.
  *
  * @tparam T   the type of the aggregation result.
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  */
abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
    * Creates and init the Accumulator for this (table)aggregate function.
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC // MANDATORY

  /**
    * Returns the TypeInformation of the (table)aggregate function's result.
    *
    * @return The TypeInformation of the (table)aggregate function's result or null if the result
    *         type should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null // PRE-DEFINED

  /**
    * Returns the TypeInformation of the (table)aggregate function's accumulator.
    *
    * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
}

/**
  * Base class for table aggregation functions. 
  *
  * @tparam T   the type of the aggregation result
  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
  *             aggregated values which are needed to compute an aggregation result.
  *             TableAggregateFunction represents its state using accumulator, thereby the state of
  *             the TableAggregateFunction must be put into the accumulator.
  */
abstract class TableAggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

  /**
    * Processes the input values and update the provided accumulator instance. The method
    * accumulate can be overloaded with different custom types and arguments. A TableAggregateFunction
    * requires at least one accumulate() method.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

  /**
    * Retracts the input values from the accumulator instance. The current design assumes the
    * inputs are the values that have been previously accumulated. The method retract can be
    * overloaded with different custom types and arguments. This function must be implemented for
    * datastream bounded over aggregate.
    *
    * @param accumulator           the accumulator which contains the current aggregated results
    * @param [user defined inputs] the input value (usually obtained from a new arrived data).
    */
  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

  /**
    * Merges a group of accumulator instances into one accumulator instance. This function must be
    * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
    *
    * @param accumulator  the accumulator which will keep the merged aggregate results. It should
    *                     be noted that the accumulator may contain the previous aggregated
    *                     results. Therefore user should not replace or clean this instance in the
    *                     custom merge method.
    * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
    *                     merged.
    */
  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL
  
  /**
    * Called every time when an aggregation result should be materialized. The returned value
    * could be either an early and incomplete result  (periodically emitted as data arrive) or
    * the final result of the  aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @param out         the collector used to output data
    */
  def emitValue(accumulator: ACC, out: Collector[T]): Unit // OPTIONAL

  /**
    * Called every time when an aggregation result should be materialized. The returned value
    * could be either an early and incomplete result (periodically emitted as data arrive) or
    * the final result of the aggregation.
    *
    * Different from emitValue, emitUpdateWithRetract is used to emit values that have been updated.
    * This method outputs data incrementally in retract mode, i.e., once there is an update, we
    * have to retract old records before sending new updated ones. The emitUpdateWithRetract
    * method will be used in preference to the emitValue method if both methods are defined in the
    * table aggregate function, because the method is treated to be more efficient than emitValue
    * as it can outputvalues incrementally.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @param out         the retractable collector used to output data. Use collect method
    *                    to output(add) records and use retract method to retract(delete)
    *                    records.
    */
  def emitUpdateWithRetract(accumulator: ACC, out: RetractableCollector[T]): Unit // OPTIONAL
 
  /**
    * Collects a record and forwards it. The collector can output retract messages with the retract
    * method. Note: only use it in `emitRetractValueIncrementally`.
    */
  trait RetractableCollector[T] extends Collector[T] {
    
    /**
      * Retract a record.
      *
      * @param record The record to retract.
      */
    def retract(record: T): Unit
  }
}

以下示例顯示了怎么使用

  • 定義一個TableAggregateFunction用於計算給定列的前2個值
  • TableEnvironment注冊函數
  • 在Table API查詢中使用該函數(Table API僅支持TableAggregateFunction)。

要計算前2個值,累加器需要存儲所有已累加數據中的最大2個值。 在我們的示例中,我們定義了一個Top2Accum類作為累加器。 累加器由Flink的檢查點機制自動備份,並在無法確保一次准確語義的情況下恢復。

我們的Top2 TableAggregateFunction的accumulate()方法有兩個輸入。 第一個是Top2Accum累加器,另一個是用戶定義的輸入:輸入值v。盡管merge()方法對於大多數表聚合類型不是強制性的,但我們在下面提供了示例。 請注意,我們在Scala示例中使用了Java基本類型並定義了getResultType()和getAccumulatorType()方法,因為Flink類型提取不適用於Scala類型。

import java.lang.{Integer => JInteger}
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.TableAggregateFunction

/**
 * Accumulator for top2.
 */
class Top2Accum {
  var first: JInteger = _
  var second: JInteger = _
}

/**
 * The top2 user-defined table aggregate function.
 */
class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {

  override def createAccumulator(): Top2Accum = {
    val acc = new Top2Accum
    acc.first = Int.MinValue
    acc.second = Int.MinValue
    acc
  }

  def accumulate(acc: Top2Accum, v: Int) {
    if (v > acc.first) {
      acc.second = acc.first
      acc.first = v
    } else if (v > acc.second) {
      acc.second = v
    }
  }

  def merge(acc: Top2Accum, its: JIterable[Top2Accum]): Unit = {
    val iter = its.iterator()
    while (iter.hasNext) {
      val top2 = iter.next()
      accumulate(acc, top2.first)
      accumulate(acc, top2.second)
    }
  }

  def emitValue(acc: Top2Accum, out: Collector[JTuple2[JInteger, JInteger]]): Unit = {
    // emit the value and rank
    if (acc.first != Int.MinValue) {
      out.collect(JTuple2.of(acc.first, 1))
    }
    if (acc.second != Int.MinValue) {
      out.collect(JTuple2.of(acc.second, 2))
    }
  }
}

// init table
val tab = ...

// use function
tab
  .groupBy('key)
  .flatAggregate(top2('a) as ('v, 'rank))
  .select('key, 'v, 'rank)

以下示例顯示如何使用emitUpdateWithRetract方法僅發出更新。 為了僅發出更新,在我們的示例中,累加器同時保留了舊的和新的前2個值。 注意:如果topN的N大,則保留舊值和新值都可能無效。 解決這種情況的一種方法是將輸入記錄以累加方法存儲到累加器中,然后在emitUpdateWithRetract中執行計算。

import java.lang.{Integer => JInteger}
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.TableAggregateFunction

/**
 * Accumulator for top2.
 */
class Top2Accum {
  var first: JInteger = _
  var second: JInteger = _
  var oldFirst: JInteger = _
  var oldSecond: JInteger = _
}

/**
 * The top2 user-defined table aggregate function.
 */
class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {

  override def createAccumulator(): Top2Accum = {
    val acc = new Top2Accum
    acc.first = Int.MinValue
    acc.second = Int.MinValue
    acc.oldFirst = Int.MinValue
    acc.oldSecond = Int.MinValue
    acc
  }

  def accumulate(acc: Top2Accum, v: Int) {
    if (v > acc.first) {
      acc.second = acc.first
      acc.first = v
    } else if (v > acc.second) {
      acc.second = v
    }
  }

  def emitUpdateWithRetract(
    acc: Top2Accum,
    out: RetractableCollector[JTuple2[JInteger, JInteger]])
  : Unit = {
    if (acc.first != acc.oldFirst) {
      // if there is an update, retract old value then emit new value.
      if (acc.oldFirst != Int.MinValue) {
        out.retract(JTuple2.of(acc.oldFirst, 1))
      }
      out.collect(JTuple2.of(acc.first, 1))
      acc.oldFirst = acc.first
    }
    if (acc.second != acc.oldSecond) {
      // if there is an update, retract old value then emit new value.
      if (acc.oldSecond != Int.MinValue) {
        out.retract(JTuple2.of(acc.oldSecond, 2))
      }
      out.collect(JTuple2.of(acc.second, 2))
      acc.oldSecond = acc.second
    }
  }
}

// init table
val tab = ...

// use function
tab
  .groupBy('key)
  .flatAggregate(top2('a) as ('v, 'rank))
  .select('key, 'v, 'rank)

實施UDF的最佳做法

Table API和SQL代碼生成在內部嘗試盡可能多地使用原始值。 用戶定義的函數可能會通過對象創建,轉換和裝箱帶來很多開銷。 因此,強烈建議將參數和結果類型聲明為基本類型,而不是其框內的類。 Types.DATE和Types.TIME也可以表示為int。 Types.TIMESTAMP可以表示為long。

我們建議用戶定義的函數應使用Java而不是Scala編寫,因為Scala類型對Flink的類型提取器構成了挑戰。

將UDF與 Runtime 集成

 有時,用戶定義的函數可能有必要在實際工作之前獲取全局運行時信息或進行一些設置/清理工作。 用戶定義的函數提供可被覆蓋的open()和close()方法,並提供與DataSet或DataStream API的RichFunction中的方法相似的功能。

open()方法在評估方法之前被調用一次。 最后一次調用評估方法之后調用close()方法。

open()方法提供一個FunctionContext,其中包含有關在其中執行用戶定義的函數的上下文的信息,例如度量標准組,分布式緩存文件或全局作業參數。

通過調用FunctionContext的相應方法可以獲得以下信息:

方法 描述
getMetricGroup() 此並行子任務的度量標准組。
getCachedFile(name) 分布式緩存文件的本地臨時文件副本。
getJobParameter(name, defaultValue) 與給定鍵關聯的全局作業參數值。

以下示例片段顯示了如何FunctionContext在標量函數中使用它來訪問全局job參數:

object hashCode extends ScalarFunction {

  var hashcode_factor = 12

  override def open(context: FunctionContext): Unit = {
    // access "hashcode_factor" parameter
    // "12" would be the default value if parameter does not exist
    hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt
  }

  def eval(s: String): Int = {
    s.hashCode() * hashcode_factor
  }
}

val tableEnv = BatchTableEnvironment.create(env)

// use the function in Scala Table API
myTable.select('string, hashCode('string))

// register and use the function in SQL
tableEnv.registerFunction("hashCode", hashCode)
tableEnv.sqlQuery("SELECT string, HASHCODE(string) FROM MyTable")

歡迎關注Flink菜鳥公眾號,會不定期更新Flink(開發技術)相關的推文

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM