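Spark also supports Hive-style user-defined functions. They fall roughly into three kinds:
- UDF (User-Defined Function): the basic scalar function, similar to to_char or to_date
- UDAF (User-Defined Aggregate Function): a user-defined aggregate function, similar to sum or avg used with GROUP BY
- UDTF (User-Defined Table-Generating Function): a user-defined table-generating function that turns one row into several, a bit like flatMap on a stream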
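To make the flatMap comparison concrete, here is a minimal plain-Java sketch with no Spark involved. The class name `ExplodeSketch` and the helper `explode` are hypothetical names chosen for illustration; the point is only that one input value can produce several output values, which is the shape of work a table-generating function does.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Plain-Java illustration only: one input "row" yields several output "rows".
final class ExplodeSketch {

    // Hypothetical helper: split every comma-separated line into its items.
    static List<String> explode(List<String> lines) {
        return lines.stream()
                .flatMap(line -> Arrays.stream(line.split(",")))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Two input lines become four output items.
        System.out.println(explode(Arrays.asList("a,b", "c,d")));
    }
}
```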
This post walks through writing a UDF and a UDAF step by step.
A simple UDF first
Scenario:
We have a text file like this:
1^^d
2^b^d
3^c^d
4^^d
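When reading the data, if the second column is empty it should be shown as 'null'; otherwise its value is output unchanged. Once the function is registered, it can be used directly in Spark SQL.
The code: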
package test;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
/**
* Created by xinghailong on 2017/2/23.
*/
public class test3 {
    public static void main(String[] args) {
        // Create the Spark runtime environment
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test-udf");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);
        // Register the UDF. split() yields "" (not null) for an empty column,
        // so both null and the empty string are treated as missing.
        sqlContext.udf().register("isNull",
                (String field, String defaultValue) -> (field == null || field.isEmpty()) ? defaultValue : field,
                DataTypes.StringType);
        // Read the file and turn each line into a Row with three string columns
        JavaRDD<String> lines = sc.textFile("C:\\test-udf.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create((Object[]) line.split("\\^")));
        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("c", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);
        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");
        // Apply the UDF to column b
        sqlContext.sql("SELECT a, isNull(b, 'null'), c FROM test").show();
        sc.stop();
    }
}
The output is:
+---+----+---+
|  a| _c1|  c|
+---+----+---+
|  1|null|  d|
|  2|   b|  d|
|  3|   c|  d|
|  4|null|  d|
+---+----+---+
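The key line is this one:
sqlContext.udf().register("isNull", (String field, String defaultValue) -> (field == null || field.isEmpty()) ? defaultValue : field, DataTypes.StringType);
The check covers the empty string as well as null because split() returns "" for an empty column, not null. I wrote it with Java 8 lambda syntax; on earlier Java versions you need to pass an anonymous class implementing UDF2 (org.apache.spark.sql.api.java.UDF2) instead.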
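The behaviour of the sample lines can be checked without Spark. The following plain-Java sketch splits one of the sample lines and applies a default-value helper; the class name `DefaultValueSketch` and the helper `orDefault` are hypothetical and only mirror the idea of the registered function. It shows that the empty second column comes back from split() as an empty string, so a check against null alone would not substitute the default.

```java
import java.util.Arrays;

// Plain-Java illustration only; no Spark classes are used.
final class DefaultValueSketch {

    // Hypothetical helper: return defaultValue when the field is null or empty.
    static String orDefault(String field, String defaultValue) {
        return (field == null || field.isEmpty()) ? defaultValue : field;
    }

    public static void main(String[] args) {
        // The caret is a regex metacharacter, so it is escaped.
        String[] parts = "1^^d".split("\\^");
        // Three elements; the middle one is "" rather than null.
        System.out.println(Arrays.toString(parts));
        System.out.println(orDefault(parts[1], "null"));
        System.out.println(orDefault("b", "null"));
    }
}
```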
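Next, a custom UDAF: computing an average
Let's start with the simplest UDAF, an average. Many operations follow the same pattern, such as max, min, running sums and string concatenation.
First, define the UDAF class: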
package test;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
/**
* Created by xinghailong on 2017/2/23.
*/
public class MyAvg extends UserDefinedAggregateFunction {
    // Input: a single string column holding the number to average
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }
    // Buffer: field1 = row count, field2 = running sum
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("field2", DataTypes.IntegerType, true));
        return DataTypes.createStructType(structFields);
    }
    // Result type: a double, so the average is not truncated by integer division
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }
    // An average always gives the same result for the same input
    @Override
    public boolean deterministic() {
        return true;
    }
    // Start every group with count = 0 and sum = 0
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0);
        buffer.update(1, 0);
    }
    // Fold one input row into the buffer; null values are skipped
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (input.isNullAt(0)) {
            return;
        }
        buffer.update(0, buffer.getInt(0) + 1);
        buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0)));
    }
    // Combine two partial buffers (for example from different partitions)
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
    }
    // Final result: sum / count, or null when no value was seen
    @Override
    public Object evaluate(Row buffer) {
        if (buffer.getInt(0) == 0) {
            return null;
        }
        return (double) buffer.getInt(1) / buffer.getInt(0);
    }
}
To use it, register the function first; after that it can be called directly in Spark SQL:
package test;
// MyAvg is declared in this same package (test), so it needs no import
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
/**
* Created by xinghailong on 2017/2/23.
*/
public class test4 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);
        // Register the UDAF under the name used in SQL
        sqlContext.udf().register("my_avg", new MyAvg());
        // Read the file: column a is the group key, column b the value to average
        JavaRDD<String> lines = sc.textFile("C:\\test4.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create((Object[]) line.split("\\^")));
        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);
        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");
        sqlContext.sql("SELECT my_avg(b) FROM test GROUP BY a").show();
        sc.stop();
    }
}
The input text file contains:
a^3
a^6
b^2
b^4
b^6
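To see how the four UDAF callbacks cooperate on this data, here is a plain-Java sketch that mimics the count-and-sum buffer with an int array. It uses no Spark classes; the class name `AvgBufferSketch` and the split of group b over two partitions are assumptions made for illustration. Note that evaluate here divides as a double: with integer division of sum by count, group a (3 and 6) would come out as 4 instead of 4.5.

```java
// Plain-Java illustration of the initialize / update / merge / evaluate cycle.
final class AvgBufferSketch {

    // Buffer layout mirrors the UDAF: [0] = row count, [1] = running sum.
    static int[] initialize() {
        return new int[] {0, 0};
    }

    // Fold one input value into a partial buffer.
    static void update(int[] buffer, String input) {
        buffer[0] += 1;
        buffer[1] += Integer.parseInt(input);
    }

    // Combine two partial buffers into the first one.
    static void merge(int[] buffer1, int[] buffer2) {
        buffer1[0] += buffer2[0];
        buffer1[1] += buffer2[1];
    }

    // Final result: sum / count as a double.
    static double evaluate(int[] buffer) {
        return (double) buffer[1] / buffer[0];
    }

    public static void main(String[] args) {
        // Group "b" (2, 4, 6), assumed to be spread over two partitions.
        int[] p1 = initialize();
        update(p1, "2");
        update(p1, "4");
        int[] p2 = initialize();
        update(p2, "6");
        merge(p1, p2);
        System.out.println("b -> " + evaluate(p1));

        // Group "a" (3, 6) in a single partition.
        int[] a = initialize();
        update(a, "3");
        update(a, "6");
        System.out.println("a -> " + evaluate(a));
    }
}
```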
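Now an all-purpose UDAF
Real business scenarios always bring odd requirements, for example:
- group by a field and take the maximum value within each group
- group by a field and accumulate a specific field over the rows of each group
- group by a field and concatenate strings for the rows that meet a specific condition
Or take a scenario where the data is grouped by a field, the rows in each group are then de-duplicated by some column, and only then is the value computed:
- 1 group by a field
- 2 check a condition within the group
- 3 then process the field
Without a UDAF, in plain Spark you might have to write something like this: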
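// Assumption: rdd is a JavaRDD of some record type P with getKey(), getBs() and getValue()
rdd.groupBy(r -> r.getKey())
   .map(t2 -> {
       // A LinkedHashSet de-duplicates while keeping first-seen order
       Set<String> set = new LinkedHashSet<>();
       for (P p : t2._2()) {
           if (p.getBs() > 0) {
               set.add(p.getValue());
           }
       }
       return String.join(",", set);
   });
The above is only a sketch and is not guaranteed to run as is.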
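For comparison, the same group, filter, de-duplicate and join steps can be written as a small runnable program using only plain Java collections, with no Spark involved. The record type `Rec`, the class `GroupJoinSketch` and the method names are hypothetical; the rows follow the shape key, value, condition and mirror the sample data used to test the UDAF in this post.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

// Plain-Java illustration only; it mimics the hand-written grouping logic.
final class GroupJoinSketch {

    // Hypothetical record: key = grouping column, value = column to join, bs = condition column.
    static final class Rec {
        final String key;
        final String value;
        final int bs;

        Rec(String key, String value, int bs) {
            this.key = key;
            this.value = value;
            this.bs = bs;
        }
    }

    static List<Rec> sampleRows() {
        return Arrays.asList(
                new Rec("a", "1111", 2), new Rec("a", "1111", 2),
                new Rec("a", "2222", 0), new Rec("a", "3333", 1),
                new Rec("b", "4444", 0), new Rec("b", "5555", 3),
                new Rec("c", "6666", 0));
    }

    // Per key: keep values whose bs > 0, drop duplicates, join with commas.
    static Map<String, String> joinByKey(List<Rec> recs) {
        Map<String, Set<String>> groups = new TreeMap<>();
        for (Rec r : recs) {
            Set<String> values = groups.computeIfAbsent(r.key, k -> new LinkedHashSet<>());
            if (r.bs > 0) {
                values.add(r.value);
            }
        }
        Map<String, String> result = new TreeMap<>();
        for (Map.Entry<String, Set<String>> e : groups.entrySet()) {
            result.put(e.getKey(), String.join(",", e.getValue()));
        }
        return result;
    }

    public static void main(String[] args) {
        // Group c has no row with bs > 0, so its joined string is empty.
        System.out.println(joinByKey(sampleRows()));
    }
}
```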
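Written this way the code does meet the requirement, but it is a bit ugly and not as clear to read as Spark SQL.
So let's try another version using a UDAF in Spark SQL!
First, create the UDAF class: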
package test;
import org.apache.commons.lang.StringUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.*;
import java.util.*;
/**
*
* Created by xinghailong on 2017/2/23.
*/
public class ConditionJoinUDAF extends UserDefinedAggregateFunction {
    // Input: field1 = condition column (int), field2 = value to concatenate
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field1", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("field2", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }
    // Buffer: one comma-separated string of the values collected so far
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field", DataTypes.StringType, true));
        return DataTypes.createStructType(structFields);
    }
    @Override
    public DataType dataType() {
        return DataTypes.StringType;
    }
    @Override
    public boolean deterministic() { // whether the same input always yields the same result
        return false;
    }
    @Override
    public void initialize(MutableAggregationBuffer buffer) { // start with an empty buffer
        buffer.update(0, "");
    }
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) { // fold one input row into this partition's buffer
        if (input.isNullAt(0) || input.isNullAt(1)) {
            return;
        }
        int bs = input.getInt(0);
        String field = buffer.getString(0);
        String in = input.getString(1);
        // Compare whole tokens; a substring test would treat "111" as already present in "1111"
        if (bs > 0 && !in.isEmpty() && !Arrays.asList(field.split(",")).contains(in)) {
            field += "," + in;
        }
        buffer.update(0, field);
    }
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) { // combine partial buffers from different partitions
        String field1 = buffer1.getString(0);
        String field2 = buffer2.getString(0);
        if (!"".equals(field2)) {
            field1 += "," + field2;
        }
        buffer1.update(0, field1);
    }
    @Override
    public Object evaluate(Row buffer) { // compute the final result from the buffer
        // distinct() drops values that were collected in more than one partition
        return StringUtils.join(Arrays.stream(buffer.getString(0).split(",")).filter(s -> !s.isEmpty()).distinct().toArray(), ",");
    }
}
Let's try it on an example:
a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^1111^2
a^2222^0
a^3333^1
b^4444^0
b^5555^3
c^6666^0
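Group by the first column and, for the rows whose third column is greater than 0, concatenate the distinct values of the second column.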
package test;
import test.ConditionJoinUDAF;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
/**
* Created by xinghailong on 2017/2/23.
*/
public class test2 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        SQLContext sqlContext = new SQLContext(sc);
        // Register the UDAF under the name used in SQL
        sqlContext.udf().register("con_join", new ConditionJoinUDAF());
        // Read the file: a = group key, b = value to join, c = condition column
        JavaRDD<String> lines = sc.textFile("C:\\test2.txt");
        JavaRDD<Row> rows = lines.map(line -> RowFactory.create((Object[]) line.split("\\^")));
        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("a", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("c", DataTypes.StringType, true));
        StructType structType = DataTypes.createStructType(structFields);
        DataFrame test = sqlContext.createDataFrame(rows, structType);
        test.registerTempTable("test");
        // First argument is the condition column, second the value to concatenate
        sqlContext.sql("SELECT con_join(c,b) FROM test GROUP BY a").show();
        sc.stop();
    }
}
This way the SQL is concise and states the intent clearly.