在Spark中,也支持Hive中的自定義函數。自定義函數大致可以分為三種:
- UDF(User-Defined-Function),即最基本的自定義函數,類似to_char,to_date等
- UDAF(User- Defined Aggregation Funcation),用戶自定義聚合函數,類似在group by之后使用的sum,avg等
- UDTF(User-Defined Table-Generating Functions),用戶自定義生成函數,有點像stream里面的flatMap
自定義一個UDF函數需要繼承UserDefinedAggregateFunction類,並實現其中的8個方法
示例
import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} object GetDistinctCityUDF extends UserDefinedAggregateFunction{ /** * 輸入的數據類型 * */ override def inputSchema: StructType = StructType( StructField("status",StringType,true) :: Nil ) /** * 緩存字段類型 * */ override def bufferSchema: StructType = { StructType( Array( StructField("buffer_city_info",StringType,true) ) ) } /** * 輸出結果類型 * */ override def dataType: DataType = StringType /** * 輸入類型和輸出類型是否一致 * */ override def deterministic: Boolean = true /** * 對輔助字段進行初始化 * */ override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer.update(0,"") } /** *修改輔助字段的值 * */ override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { //獲取最后一次的值 var last_str = buffer.getString(0) //獲取當前的值 val current_str = input.getString(0) //判斷最后一次的值是否包含當前的值 if(!last_str.contains(current_str)){ //判斷是否是第一個值,是的話走if賦值,不是的話走else追加 if(last_str.equals("")){ last_str = current_str }else{ last_str += "," + current_str } } buffer.update(0,last_str) } /** *對分區結果進行合並 * buffer1是機器hadoop1上的結果 * buffer2是機器Hadoop2上的結果 * */ override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { var buf1 = buffer1.getString(0) val buf2 = buffer2.getString(0) //將buf2里面存在的數據而buf1里面沒有的數據追加到buf1 //buf2的數據按照,進行切分 for(s <- buf2.split(",")){ if(!buf1.contains(s)){ if(buf1.equals("")){ buf1 = s }else{ buf1 += s } } } buffer1.update(0,buf1) } /** * 最終的計算結果 * */ override def evaluate(buffer: Row): Any = { buffer.getString(0) } }
注冊自定義的UDF函數為臨時函數
def main(args: Array[String]): Unit = { /** * 第一步 創建程序入口 */ val conf = new SparkConf().setAppName("AralHotProductSpark") val sc = new SparkContext(conf) val hiveContext = new HiveContext(sc)
//注冊成為臨時函數 hiveContext.udf.register("get_distinct_city",GetDistinctCityUDF) //注冊成為臨時函數 hiveContext.udf.register("get_product_status",(str:String) =>{ var status = 0 for(s <- str.split(",")){ if(s.contains("product_status")){ status = s.split(":")(1).toInt } } }) }