自定義UDAF函數的代碼框架
1 //首先繼承一個類AbstractGenericUDAFResolver,然后實現里面的getevaluate方法 2 public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {} 3 4 //在類里面再定義一個內部類繼承GenericUDAFEvaluator並重寫里面的幾個方法 5 6 public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException; 7 8 abstract AggregationBuffer getNewAggregationBuffer() throws HiveException; 9 10 public void reset(AggregationBuffer agg) throws HiveException; 11 12 public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException; 13 14 public Object terminatePartial(AggregationBuffer agg) throws HiveException; 15 16 public void merge(AggregationBuffer agg, Object partial) throws HiveException; 17 18 public Object terminate(AggregationBuffer agg) throws HiveException;
//方法的具體使用說明在實例代碼中說明
自己實現count聚合函數的Java代碼（類名為Sum，但實際統計的是非空值的個數，相當於COUNT(col)）
public class Sum extends AbstractGenericUDAFResolver { //創建log對象,用於拋出錯誤和異常 static final Log log = LogFactory.getLog(Sum.class.getName()); //判斷sql語句傳入的參數的個數和類型,並將其返回相應的類型 @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException { //判斷參數的個數是否符合要求 if (info.length != 1) { throw new UDFArgumentTypeException(info.length - 1, "exactly one parameter expected"); } //判斷傳入的參數類型 if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { throw new UDFArgumentTypeException(0, "only primitive argument is expected but " + info[0].getTypeName() + "is passed"); } //對傳入的參數類型進行進一步的判斷是否是我們需求的數據的類型 switch (((PrimitiveTypeInfo) info[0]).getPrimitiveCategory()) { case BYTE: case SHORT: case INT: case LONG: case FLOAT: case DOUBLE: return new SumRes(); default: throw new UDFArgumentTypeException(0, "only numric type is expected but " + info[0].getTypeName() + "is passed"); } } public static class SumRes extends GenericUDAFEvaluator { //創建變量存儲中間結果 //input:每一步執行時傳入的參數 //output:每一步執行時輸出的結果數據的類型 //input和output都只是指定的輸入輸出的數據類型而已,和數據計算本身無關 //result是聚合的結果的數據,和用於particial2和final階段的結果輸出,genuine不同的業務要求指定不同的類型等 private PrimitiveObjectInspector input; private PrimitiveObjectInspector output; private LongWritable result; //對各個階段都會首先調用一下該方法,並且對輸入輸出數據初始化 /** *Mode: * partial1 : map階段 會調用 init -> iterate -> partialterminate * partial2 : combiner階段 會調用 init -> merge -> partialterminate * final : reduce階段 會調用 init -> merge -> terminate * complete : 只有map沒有reduce階段 會調用 init -> iterate -> terminate */ @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { assert parameters.length == 1; super.init(m,parameters); //init input //將傳入的參數賦值給定義的input輸入變量 if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) { input = (PrimitiveObjectInspector)parameters[0]; }else { input = (PrimitiveObjectInspector)parameters[0]; } //init output //返回中間聚合,或最終結果的數據的類型 if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) { output = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector; }else { output = PrimitiveObjectInspectorFactory.writableLongObjectInspector; } //result用於實際接收聚合結果數據 result = new LongWritable(); return output; } //中間緩存的暫存結構,用於接收中間運行時需要暫存的變量數據 static class AggregateAgg implements AggregationBuffer{ Long sum; } @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { AggregateAgg result = new AggregateAgg(); reset(result); return result; } //刷新緩存重置暫存數據,重用jvm @Override public void reset(AggregationBuffer agg) throws HiveException { AggregateAgg myAgg = (AggregateAgg)agg; myAgg.sum = 0L; } //對map端傳入的每一條數據進行處理 @Override public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { assert parameters.length == 1; Object param = parameters[0]; if (param != null) { AggregateAgg myAgg = (AggregateAgg)agg; myAgg.sum ++; } } //返回map階段對每一條數據處理后的數據 @Override public Object terminatePartial(AggregationBuffer agg) throws HiveException { AggregateAgg myAgg = (AggregateAgg)agg; result.set(myAgg.sum); return result; } //在combiner和reduce時候回調用,對map輸出的結果進行聚合,即每一條數據調用一下,依次將數據累加到之前的結果上 @Override public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial != null) { AggregateAgg myAgg = (AggregateAgg)agg; myAgg.sum += PrimitiveObjectInspectorUtils.getLong(partial,input); } } //使用變量接收最終的結果數據,並將數據進行返回 @Override public Object terminate(AggregationBuffer agg) throws HiveException { AggregateAgg myAgg = (AggregateAgg)agg; result.set(myAgg.sum); return result; } } }
