SparkSQL自定義函數


一:自定義函數分類

在Spark中,也支持Hive中的自定義函數。自定義函數大致可以分為三種:

1.UDF(User-Defined-Function),即最基本的自定義函數,類似to_char,to_date等
2.UDAF(User- Defined Aggregation Funcation),用戶自定義聚合函數,類似在group by之后使用的sum,avg等
3.UDTF(User-Defined Table-Generating Functions),用戶自定義生成函數,有點像stream里面的flatMap

二:自定義函數的使用UDF

(一)定義case class

         case class Emp(empno:Int,ename:String,job:String,mgr:String,hiredate:String,sal:Int,comm:String,deptno:Int)

(二)導入emp.csv的文件

          val lineRDD = sc.textFile("/emp.csv").map(_.split(","))

(三)生成DataFrame

         val allEmp = lineRDD.map(x=>Emp(x(0).toInt,x(1),x(2),x(3),x(4),x(5).toInt,x(6),x(7).toInt))
         val empDF = allEmp.toDF

(四)注冊成一個臨時視圖

         empDF.createOrReplaceTempView("emp")

(五)自定義一個函數,拼加字符串

         spark.sqlContext.udf.register("concatstr",(s1:String,s2:String)=>s1+"***"+s2)

(六)調用自定義函數,將ename和job這兩個字段拼接在一起

         spark.sql("select concatstr(ename,job) from emp").show

三:用戶自定義聚合函數UDAF,需要繼承UserDefinedAggregateFunction類,並實現其中的8個方法

UDAF就是用戶自定義聚合函數,比如平均值,最大最小值,累加,拼接等。這里以求平均數為例,並用Java實現

(一)實現自定義聚合函數

package SparkUDAF;
 
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
 
import java.util.ArrayList;
import java.util.List;
 
public class MyAvg extends UserDefinedAggregateFunction {
 
    @Override
    public StructType inputSchema() {
        //輸入數據的類型,輸入的是字符串
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("InputData", DataTypes.StringType, true));
 
        return DataTypes.createStructType(structFields);
    }
 
    @Override
    public StructType bufferSchema() {
 
        //聚合操作時,所處理的數據的數據類型,在這個例子里求平均數,要先求和(Sum),然后除以個數(Amount),所以這里需要處理兩個字段
        //注意因為用了ArrayList,所以是有序的
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("Amount", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("Sum", DataTypes.IntegerType, true));
 
        return DataTypes.createStructType(structFields);
    }
 
    @Override
    public DataType dataType() {
        //UDAF計算后的返回值類型
        return DataTypes.IntegerType;
    }
 
    @Override
    public boolean deterministic() {
        //判斷輸入和輸出的類型是否一致,如果返回的是true則表示一致,false表示不一致,自行設置
        return false;
    }
 
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        /*
        對輔助字段進行初始化,就是上面定義的field1和field2
        第一個輔助字段的下標為0,初始值為0
        第二個輔助字段的下標為1,初始值為0
        */
        buffer.update(0, 0);
        buffer.update(1, 0);
    }
 
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        /*
        update可以認為是在每一個節點上都會對數據執行的操作,UDAF函數執行的時候,數據會被分發到每一個節點上,就是每一個分區
        buffer.getInt(0)獲取的是上一次聚合后的值,input就是當前獲取的數據
        */
 
        //修改輔助字段的值,buffer.getInt(x)獲取的是上一次聚合后的值,x表示
        buffer.update(0, buffer.getInt(0) + 1); //表示某個數字的個數
        buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0))); //表示某個數字的總和
    }
 
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        /*
        merge:對每個分區的結果進行合並,每個分布式的節點上做完update之后就要做一個全局合並的操作
        合並每一個update操作的結果,將各個節點上的數據合並起來
        buffer1.getInt(0) : 上一次聚合后的值
        buffer2.getInt(0) : 這次計算傳入進來的update的結果
        */
 
        //對第一個字段Amount進行求和,求出總個數
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
        //對第二個字段Sum進行求和,求出總和
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
    }
 
    @Override
    public Object evaluate(Row buffer) {
        //表示最終計算的結果,第二個參數表示和值,第一個參數表示個數
        return buffer.getInt(1) / buffer.getInt(0);
    }
}

(二)注冊並使用UDAF

package SparkUDAF;
 
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
 
import java.util.ArrayList;
import java.util.List;
 
public class TestMain {
    public static void main(String[] args) {
        SparkConf conf =new SparkConf();
        conf.setMaster("local").setAppName("MyAvg");
        JavaSparkContext sc= new JavaSparkContext(conf);
        //得到SQLContext對象
        SQLContext sqlContext = new SQLContext(sc);
 
        //注冊自定義函數
        sqlContext.udf().register("my_avg",new MyAvg());
 
        //讀入數據
        JavaRDD<String> lines = sc.textFile("d:\\test.txt");
        //分詞
        JavaRDD<Row> rows=lines.map(line-> RowFactory.create(line.split("\\^")));
 
        //定義schema的結構,a字段是字母,b字段是value
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("a",DataTypes.StringType,true));
        structFields.add(DataTypes.createStructField("b",DataTypes.StringType,true));
        StructType structType = DataTypes.createStructType(structFields);
 
        //創建DataFrame
        Dataset ds=sqlContext.createDataFrame(rows,structType);
        ds.registerTempTable("test");
 
        //執行查詢
        sqlContext.sql("select a,my_avg(b) from test group by a").show();
        sc.stop();
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM