一:自定義函數分類
在Spark中,也支持Hive中的自定義函數。自定義函數大致可以分為三種:
1.UDF(User-Defined-Function),即最基本的自定義函數,類似to_char,to_date等
2.UDAF(User- Defined Aggregation Funcation),用戶自定義聚合函數,類似在group by之后使用的sum,avg等
3.UDTF(User-Defined Table-Generating Functions),用戶自定義生成函數,有點像stream里面的flatMap
二:自定義函數的使用UDF
(一)定義case class
case class Emp(empno:Int,ename:String,job:String,mgr:String,hiredate:String,sal:Int,comm:String,deptno:Int)
(二)導入emp.csv的文件
val lineRDD = sc.textFile("/emp.csv").map(_.split(","))
(三)生成DataFrame
val allEmp = lineRDD.map(x=>Emp(x(0).toInt,x(1),x(2),x(3),x(4),x(5).toInt,x(6),x(7).toInt)) val empDF = allEmp.toDF
(四)注冊成一個臨時視圖
empDF.createOrReplaceTempView("emp")
(五)自定義一個函數,拼加字符串
spark.sqlContext.udf.register("concatstr",(s1:String,s2:String)=>s1+"***"+s2)
(六)調用自定義函數,將ename和job這兩個字段拼接在一起
spark.sql("select concatstr(ename,job) from emp").show
三:用戶自定義聚合函數UDAF,需要繼承UserDefinedAggregateFunction類,並實現其中的8個方法
UDAF就是用戶自定義聚合函數,比如平均值,最大最小值,累加,拼接等。這里以求平均數為例,並用Java實現
(一)實現自定義聚合函數
package SparkUDAF; import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; import java.util.List; public class MyAvg extends UserDefinedAggregateFunction { @Override public StructType inputSchema() { //輸入數據的類型,輸入的是字符串 List<StructField> structFields = new ArrayList<>(); structFields.add(DataTypes.createStructField("InputData", DataTypes.StringType, true)); return DataTypes.createStructType(structFields); } @Override public StructType bufferSchema() { //聚合操作時,所處理的數據的數據類型,在這個例子里求平均數,要先求和(Sum),然后除以個數(Amount),所以這里需要處理兩個字段 //注意因為用了ArrayList,所以是有序的 List<StructField> structFields = new ArrayList<>(); structFields.add(DataTypes.createStructField("Amount", DataTypes.IntegerType, true)); structFields.add(DataTypes.createStructField("Sum", DataTypes.IntegerType, true)); return DataTypes.createStructType(structFields); } @Override public DataType dataType() { //UDAF計算后的返回值類型 return DataTypes.IntegerType; } @Override public boolean deterministic() { //判斷輸入和輸出的類型是否一致,如果返回的是true則表示一致,false表示不一致,自行設置 return false; } @Override public void initialize(MutableAggregationBuffer buffer) { /* 對輔助字段進行初始化,就是上面定義的field1和field2 第一個輔助字段的下標為0,初始值為0 第二個輔助字段的下標為1,初始值為0 */ buffer.update(0, 0); buffer.update(1, 0); } @Override public void update(MutableAggregationBuffer buffer, Row input) { /* update可以認為是在每一個節點上都會對數據執行的操作,UDAF函數執行的時候,數據會被分發到每一個節點上,就是每一個分區 buffer.getInt(0)獲取的是上一次聚合后的值,input就是當前獲取的數據 */ //修改輔助字段的值,buffer.getInt(x)獲取的是上一次聚合后的值,x表示 buffer.update(0, buffer.getInt(0) + 1); //表示某個數字的個數 buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0))); //表示某個數字的總和 } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { /* merge:對每個分區的結果進行合並,每個分布式的節點上做完update之后就要做一個全局合並的操作 合並每一個update操作的結果,將各個節點上的數據合並起來 buffer1.getInt(0) : 上一次聚合后的值 buffer2.getInt(0) : 這次計算傳入進來的update的結果 */ //對第一個字段Amount進行求和,求出總個數 buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0)); //對第二個字段Sum進行求和,求出總和 buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1)); } @Override public Object evaluate(Row buffer) { //表示最終計算的結果,第二個參數表示和值,第一個參數表示個數 return buffer.getInt(1) / buffer.getInt(0); } }
(二)注冊並使用UDAF
package SparkUDAF; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; import java.util.List; public class TestMain { public static void main(String[] args) { SparkConf conf =new SparkConf(); conf.setMaster("local").setAppName("MyAvg"); JavaSparkContext sc= new JavaSparkContext(conf); //得到SQLContext對象 SQLContext sqlContext = new SQLContext(sc); //注冊自定義函數 sqlContext.udf().register("my_avg",new MyAvg()); //讀入數據 JavaRDD<String> lines = sc.textFile("d:\\test.txt"); //分詞 JavaRDD<Row> rows=lines.map(line-> RowFactory.create(line.split("\\^"))); //定義schema的結構,a字段是字母,b字段是value List<StructField> structFields = new ArrayList<>(); structFields.add(DataTypes.createStructField("a",DataTypes.StringType,true)); structFields.add(DataTypes.createStructField("b",DataTypes.StringType,true)); StructType structType = DataTypes.createStructType(structFields); //創建DataFrame Dataset ds=sqlContext.createDataFrame(rows,structType); ds.registerTempTable("test"); //執行查詢 sqlContext.sql("select a,my_avg(b) from test group by a").show(); sc.stop(); } }