hive自定義udaf函數


自定義udaf函數的代碼框架

 1 //首先繼承一個類AbstractGenericUDAFResolver,然后實現里面的getevaluate方法
 2 public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {}
 3 
 4 //在類里面再定義一個內部類繼承GenericUDAFEvaluator並重寫里面的幾個方法
 5 
 6 public  ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException;
 7  
 8 abstract AggregationBuffer getNewAggregationBuffer() throws HiveException;
 9  
10 public void reset(AggregationBuffer agg) throws HiveException;
11  
12 public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException;
13  
14 public Object terminatePartial(AggregationBuffer agg) throws HiveException;
15  
16 public void merge(AggregationBuffer agg, Object partial) throws HiveException;
17  
18 public Object terminate(AggregationBuffer agg) throws HiveException;

//方法的具體使用說明在實例代碼中說明

自己實現count聚合函數java代碼

public class Sum extends AbstractGenericUDAFResolver {
    //創建log對象,用於拋出錯誤和異常
    static final Log log = LogFactory.getLog(Sum.class.getName());


    //判斷sql語句傳入的參數的個數和類型,並將其返回相應的類型
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
        //判斷參數的個數是否符合要求
        if (info.length != 1) {
            throw new UDFArgumentTypeException(info.length - 1, "exactly one parameter expected");
        }

        //判斷傳入的參數類型
        if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "only primitive argument is expected but " +
                    info[0].getTypeName() + "is passed");
        }

        //對傳入的參數類型進行進一步的判斷是否是我們需求的數據的類型
        switch (((PrimitiveTypeInfo) info[0]).getPrimitiveCategory()) {
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
                return new SumRes();
            default:
                throw new UDFArgumentTypeException(0, "only numric type is expected but " + info[0].getTypeName() + "is passed");
        }
    }


    public static class SumRes extends GenericUDAFEvaluator {

        //創建變量存儲中間結果
        //input:每一步執行時傳入的參數
        //output:每一步執行時輸出的結果數據的類型
        //input和output都只是指定的輸入輸出的數據類型而已,和數據計算本身無關
        //result是聚合的結果的數據,和用於particial2和final階段的結果輸出,genuine不同的業務要求指定不同的類型等
        private PrimitiveObjectInspector input;
        private PrimitiveObjectInspector output;
        private LongWritable result;

        //對各個階段都會首先調用一下該方法,並且對輸入輸出數據初始化

        /**
         *Mode:
         * partial1 : map階段                會調用 init -> iterate -> partialterminate
         * partial2 : combiner階段           會調用 init -> merge -> partialterminate
         * final    : reduce階段             會調用 init -> merge -> terminate
         * complete : 只有map沒有reduce階段   會調用 init -> iterate -> terminate
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            assert parameters.length == 1;
            super.init(m,parameters);

            //init input
            //將傳入的參數賦值給定義的input輸入變量
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                input = (PrimitiveObjectInspector)parameters[0];
            }else {
                input = (PrimitiveObjectInspector)parameters[0];
            }

            //init output
            //返回中間聚合,或最終結果的數據的類型
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                output = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
            }else {
                output = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
            }
            //result用於實際接收聚合結果數據
            result = new LongWritable();
            return output;
        }


        //中間緩存的暫存結構,用於接收中間運行時需要暫存的變量數據
        static class AggregateAgg implements AggregationBuffer{
            Long sum;
        }
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AggregateAgg result = new AggregateAgg();
            reset(result);
            return result;
        }

        //刷新緩存重置暫存數據,重用jvm
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AggregateAgg myAgg = (AggregateAgg)agg;
            myAgg.sum = 0L;
        }

        //對map端傳入的每一條數據進行處理
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            assert parameters.length == 1;
            Object param = parameters[0];
            if (param != null) {
                AggregateAgg myAgg = (AggregateAgg)agg;
                myAgg.sum ++;
            }
        }

        //返回map階段對每一條數據處理后的數據
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AggregateAgg myAgg = (AggregateAgg)agg;
            result.set(myAgg.sum);
            return result;
        }

        //在combiner和reduce時候回調用,對map輸出的結果進行聚合,即每一條數據調用一下,依次將數據累加到之前的結果上
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial != null) {
                AggregateAgg myAgg = (AggregateAgg)agg;
                myAgg.sum += PrimitiveObjectInspectorUtils.getLong(partial,input);
            }
        }

        //使用變量接收最終的結果數據,並將數據進行返回
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AggregateAgg myAgg = (AggregateAgg)agg;
            result.set(myAgg.sum);
            return result;
        }
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM