1 標量函數
自定義標量函數可以把 0 到多個標量值映射成 1 個標量值,數據類型里列出的任何數據類型都可作為求值方法的參數和返回值類型。
想要實現自定義標量函數,你需要擴展 org.apache.flink.table.functions
里面的 ScalarFunction
並且實現一個或者多個求值方法。標量函數的行為取決於你寫的求值方法。求值方法必須是 public
的,而且名字必須是 eval
。
下面的例子展示了如何實現一個求哈希值的函數並在查詢里調用它,詳情可參考開發指南:
import org.apache.flink.table.annotation.InputGroup import org.apache.flink.table.api._ import org.apache.flink.table.functions.ScalarFunction class HashFunction extends ScalarFunction { // 接受任意類型輸入,返回 INT 型輸出 def eval(@DataTypeHint(inputGroup = InputGroup.ANY) o: AnyRef): Int { return o.hashCode(); } } val env = TableEnvironment.create(...) // 在 Table API 里不經注冊直接“內聯”調用函數 env.from("MyTable").select(call(classOf[HashFunction], $"myField")) // 注冊函數 env.createTemporarySystemFunction("HashFunction", classOf[HashFunction]) // 在 Table API 里調用注冊好的函數 env.from("MyTable").select(call("HashFunction", $"myField")) // 在 SQL 里調用注冊好的函數 env.sqlQuery("SELECT HashFunction(myField) FROM MyTable")
如果你打算使用 Python 實現或調用標量函數,詳情可參考 Python 標量函數。
2 表值函數
跟自定義標量函數一樣,自定義表值函數的輸入參數也可以是 0 到多個標量。但是跟標量函數只能返回一個值不同的是,它可以返回任意多行。返回的每一行可以包含 1 到多列,如果輸出行只包含 1 列,會省略結構化信息並生成標量值,這個標量值在運行階段會隱式地包裝進行里。
要定義一個表值函數,你需要擴展 org.apache.flink.table.functions
下的 TableFunction
,可以通過實現多個名為 eval
的方法對求值方法進行重載。像其他函數一樣,輸入和輸出類型也可以通過反射自動提取出來。表值函數返回的表的類型取決於 TableFunction
類的泛型參數 T
,不同於標量函數,表值函數的求值方法本身不包含返回類型,而是通過 collect(T)
方法來發送要輸出的行。
在 Table API 中,表值函數是通過 .joinLateral(...)
或者 .leftOuterJoinLateral(...)
來使用的。joinLateral
算子會把外表(算子左側的表)的每一行跟跟表值函數返回的所有行(位於算子右側)進行 (cross)join。leftOuterJoinLateral
算子也是把外表(算子左側的表)的每一行跟表值函數返回的所有行(位於算子右側)進行(cross)join,並且如果表值函數返回 0 行也會保留外表的這一行。
在 SQL 里面用 JOIN
或者 以 ON TRUE
為條件的 LEFT JOIN
來配合 LATERAL TABLE(<TableFunction>)
的使用。
下面的例子展示了如何實現一個分隔函數並在查詢里調用它,詳情可參考開發指南:
import org.apache.flink.table.annotation.DataTypeHint import org.apache.flink.table.annotation.FunctionHint import org.apache.flink.table.api._ import org.apache.flink.table.functions.TableFunction import org.apache.flink.types.Row @FunctionHint(output = new DataTypeHint("ROW<word STRING, length INT>")) class SplitFunction extends TableFunction[Row] { def eval(str: String): Unit = { // use collect(...) to emit a row str.split(" ").foreach(s => collect(Row.of(s, Int.box(s.length)))) } } val env = TableEnvironment.create(...) // 在 Table API 里不經注冊直接“內聯”調用函數 env .from("MyTable") .joinLateral(call(classOf[SplitFunction], $"myField") .select($"myField", $"word", $"length") env .from("MyTable") .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField")) .select($"myField", $"word", $"length") // 在 Table API 里重命名函數字段 env .from("MyTable") .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField").as("newWord", "newLength")) .select($"myField", $"newWord", $"newLength") // 注冊函數 env.createTemporarySystemFunction("SplitFunction", classOf[SplitFunction]) // 在 Table API 里調用注冊好的函數 env .from("MyTable") .joinLateral(call("SplitFunction", $"myField")) .select($"myField", $"word", $"length") env .from("MyTable") .leftOuterJoinLateral(call("SplitFunction", $"myField")) .select($"myField", $"word", $"length") // 在 SQL 里調用注冊好的函數 env.sqlQuery( "SELECT myField, word, length " + "FROM MyTable, LATERAL TABLE(SplitFunction(myField))"); env.sqlQuery( "SELECT myField, word, length " + "FROM MyTable " + "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) ON TRUE") // 在 SQL 里重命名函數字段 env.sqlQuery( "SELECT myField, newWord, newLength " + "FROM MyTable " + "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) AS T(newWord, newLength) ON TRUE")
如果你打算使用 Scala,不要把表值函數聲明為 Scala object
,Scala object
是單例對象,將導致並發問題。
如果你打算使用 Python 實現或調用表值函數,詳情可參考 Python 表值函數。
3 聚合函數
自定義聚合函數(UDAGG)是把一個表(一行或者多行,每行可以有一列或者多列)聚合成一個標量值
上面的圖片展示了一個聚合的例子。假設你有一個關於飲料的表。表里面有三個字段,分別是 id
、name
、price
,表里有 5 行數據。假設你需要找到所有飲料里最貴的飲料的價格,即執行一個 max()
聚合。你需要遍歷所有 5 行數據,而結果就只有一個數值。
自定義聚合函數是通過擴展 AggregateFunction
來實現的。AggregateFunction
的工作過程如下。首先,它需要一個 accumulator
,它是一個數據結構,存儲了聚合的中間結果。通過調用 AggregateFunction
的 createAccumulator()
方法創建一個空的 accumulator。接下來,對於每一行數據,會調用 accumulate()
方法來更新 accumulator。當所有的數據都處理完了之后,通過調用 getValue
方法來計算和返回最終的結果。
下面幾個方法是每個 AggregateFunction
必須要實現的:
createAccumulator()
accumulate()
getValue()
Flink 的類型推導在遇到復雜類型的時候可能會推導出錯誤的結果,比如那些非基本類型和普通的 POJO 類型的復雜類型。所以跟 ScalarFunction
和 TableFunction
一樣,AggregateFunction
也提供了 AggregateFunction#getResultType()
和 AggregateFunction#getAccumulatorType()
來分別指定返回值類型和 accumulator 的類型,兩個函數的返回值類型也都是 TypeInformation
。
除了上面的方法,還有幾個方法可以選擇實現。這些方法有些可以讓查詢更加高效,而有些是在某些特定場景下必須要實現的。例如,如果聚合函數用在會話窗口(當兩個會話窗口合並的時候需要 merge 他們的 accumulator)的話,merge()
方法就是必須要實現的。
AggregateFunction
的以下方法在某些場景下是必須實現的:
retract()
在 boundedOVER
窗口中是必須實現的。merge()
在許多批式聚合和會話窗口聚合中是必須實現的。resetAccumulator()
在許多批式聚合中是必須實現的。
AggregateFunction
的所有方法都必須是 public
的,不能是 static
的,而且名字必須跟上面寫的一樣。createAccumulator
、getValue
、getResultType
以及 getAccumulatorType
這幾個函數是在抽象類 AggregateFunction
中定義的,而其他函數都是約定的方法。如果要定義一個聚合函數,你需要擴展 org.apache.flink.table.functions.AggregateFunction
,並且實現一個(或者多個)accumulate
方法。accumulate
方法可以重載,每個方法的參數類型不同,並且支持變長參數。
AggregateFunction
的所有方法的詳細文檔如下。

/** * Base class for user-defined aggregates and table aggregates. * * @tparam T the type of the aggregation result. * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the * aggregated values which are needed to compute an aggregation result. */ abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction { /** * Creates and init the Accumulator for this (table)aggregate function. * * @return the accumulator with the initial value */ def createAccumulator(): ACC // MANDATORY /** * Returns the TypeInformation of the (table)aggregate function's result. * * @return The TypeInformation of the (table)aggregate function's result or null if the result * type should be automatically inferred. */ def getResultType: TypeInformation[T] = null // PRE-DEFINED /** * Returns the TypeInformation of the (table)aggregate function's accumulator. * * @return The TypeInformation of the (table)aggregate function's accumulator or null if the * accumulator type should be automatically inferred. */ def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED } /** * Base class for aggregation functions. * * @tparam T the type of the aggregation result * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the * aggregated values which are needed to compute an aggregation result. * AggregateFunction represents its state using accumulator, thereby the state of the * AggregateFunction must be put into the accumulator. */ abstract class AggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] { /** * Processes the input values and update the provided accumulator instance. The method * accumulate can be overloaded with different custom types and arguments. An AggregateFunction * requires at least one accumulate() method. * * @param accumulator the accumulator which contains the current aggregated results * @param [user defined inputs] the input value (usually obtained from a new arrived data). */ def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY /** * Retracts the input values from the accumulator instance. The current design assumes the * inputs are the values that have been previously accumulated. The method retract can be * overloaded with different custom types and arguments. This function must be implemented for * datastream bounded over aggregate. * * @param accumulator the accumulator which contains the current aggregated results * @param [user defined inputs] the input value (usually obtained from a new arrived data). */ def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL /** * Merges a group of accumulator instances into one accumulator instance. This function must be * implemented for datastream session window grouping aggregate and dataset grouping aggregate. * * @param accumulator the accumulator which will keep the merged aggregate results. It should * be noted that the accumulator may contain the previous aggregated * results. Therefore user should not replace or clean this instance in the * custom merge method. * @param its an [[java.lang.Iterable]] pointed to a group of accumulators that will be * merged. */ def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL /** * Called every time when an aggregation result should be materialized. * The returned value could be either an early and incomplete result * (periodically emitted as data arrive) or the final result of the * aggregation. * * @param accumulator the accumulator which contains the current * aggregated results * @return the aggregation result */ def getValue(accumulator: ACC): T // MANDATORY /** * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for * dataset grouping aggregate. * * @param accumulator the accumulator which needs to be reset */ def resetAccumulator(accumulator: ACC): Unit // OPTIONAL /** * Returns true if this AggregateFunction can only be applied in an OVER window. * * @return true if the AggregateFunction requires an OVER window, false otherwise. */ def requiresOver: Boolean = false // PRE-DEFINED }
下面的例子展示了如何:
- 定義一個聚合函數來計算某一列的加權平均,
- 在
TableEnvironment
中注冊函數, - 在查詢中使用函數。
為了計算加權平均值,accumulator 需要存儲加權總和以及數據的條數。在我們的例子里,我們定義了一個類 WeightedAvgAccum
來作為 accumulator。Flink 的 checkpoint 機制會自動保存 accumulator,在失敗時進行恢復,以此來保證精確一次的語義。
我們的 WeightedAvg
(聚合函數)的 accumulate
方法有三個輸入參數。第一個是 WeightedAvgAccum
accumulator,另外兩個是用戶自定義的輸入:輸入的值 ivalue
和 輸入的權重 iweight
。盡管 retract()
、merge()
、resetAccumulator()
這幾個方法在大多數聚合類型中都不是必須實現的,我們也在樣例中提供了他們的實現。請注意我們在 Scala 樣例中也是用的是 Java 的基礎類型,並且定義了 getResultType()
和 getAccumulatorType()
,因為 Flink 的類型推導對於 Scala 的類型推導做的不是很好。
import java.lang.{Long => JLong, Integer => JInteger} import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.api.Types import org.apache.flink.table.functions.AggregateFunction /** * Accumulator for WeightedAvg. */ class WeightedAvgAccum extends JTuple1[JLong, JInteger] { sum = 0L count = 0 } /** * Weighted Average user-defined aggregate function. */ class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] { override def createAccumulator(): WeightedAvgAccum = { new WeightedAvgAccum } override def getValue(acc: WeightedAvgAccum): JLong = { if (acc.count == 0) { null } else { acc.sum / acc.count } } def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = { acc.sum += iValue * iWeight acc.count += iWeight } def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = { acc.sum -= iValue * iWeight acc.count -= iWeight } def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = { val iter = it.iterator() while (iter.hasNext) { val a = iter.next() acc.count += a.count acc.sum += a.sum } } def resetAccumulator(acc: WeightedAvgAccum): Unit = { acc.count = 0 acc.sum = 0L } override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = { new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT) } override def getResultType: TypeInformation[JLong] = Types.LONG } // 注冊函數 val tEnv: StreamTableEnvironment = ??? tEnv.registerFunction("wAvg", new WeightedAvg()) // 使用函數 tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")