Hive UDAF開發之同時計算最大值與最小值


卷首語

前一篇文章《Hive UDAF開發入門和運行過程詳解(轉)》裡面講過UDAF的開發過程,其中提到:如果要深入理解UDAF的執行,可以看看求平均值的UDAF的源碼。

本人在看完源碼後,還是沒能完全理解裡面的內容,於是動手自己開發一個新的函數,試圖從實踐中理解它。

 

函數功能介紹

這個函數的功能有點特別。我們都知道Hive中有幾個常用的聚合函數:sum、max、min、avg。

現在要用一個函數同時實現兩個不同的功能:對於同一個key,返回其value集合中的最大值與最小值。

這裡面涉及一個難點:函數每次接收到的數據只有一個,但要同時產生兩個新的數據,且兩者之間具有一定的邏輯關係。

用文字描述這件事我不太擅長,想了好久,還是直接上代碼吧。

 

源碼

 

package org.juefan.udaf;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;

/**
 * GenericUDAFMaxMin: a Hive UDAF that computes both the maximum and the minimum of
 * a group of values in a single aggregation and returns them as one string of the
 * form {@code "<max>\t<min>"}.
 *
 * <p>Accepted argument types are BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, STRING and
 * TIMESTAMP. Every value is converted to {@code double} with
 * {@code PrimitiveObjectInspectorUtils.getDouble}; NULL inputs are skipped, and
 * values that throw {@link NumberFormatException} during conversion are skipped
 * (a warning is logged once per evaluator). A group that contains no usable value
 * yields NULL.
 */
@Description(name = "maxmin", value = "_FUNC_(x) - Returns the max and min value of a set of numbers")
public class GenericUDAFMaxMin extends AbstractGenericUDAFResolver {

    static final Log LOG = LogFactory.getLog(GenericUDAFMaxMin.class.getName());

    /**
     * Validates the argument list and returns the evaluator.
     *
     * @param parameters type info of the call's arguments; exactly one primitive
     *                   numeric, string or timestamp argument is accepted
     * @return a new {@link GenericUDAFMaxMinEvaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) when the
     *                           argument count or type is not supported
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
            throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                    "Exactly one argument is expected.");
        }

        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                    "Only primitive type arguments are accepted but "
                            + parameters[0].getTypeName() + " is passed.");
        }
        switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
        case BYTE:
        case SHORT:
        case INT:
        case LONG:
        case FLOAT:
        case DOUBLE:
        case STRING:
        case TIMESTAMP:
            return new GenericUDAFMaxMinEvaluator();
        case BOOLEAN:
        default:
            throw new UDFArgumentTypeException(0,
                    "Only numeric or string type arguments are accepted but "
                            + parameters[0].getTypeName() + " is passed.");
        }
    }

    /**
     * GenericUDAFMaxMinEvaluator.
     *
     * <p>Partial results travel between stages as a struct {@code {max: double, min: double}};
     * the final result is a string {@code "max\tmin"}.
     */
    public static class GenericUDAFMaxMinEvaluator extends GenericUDAFEvaluator {

        // Input side for PARTIAL1 and COMPLETE: inspector for the raw column values.
        PrimitiveObjectInspector inputOI;

        // Input side for PARTIAL2 and FINAL: inspectors for the {max, min} struct
        // produced by terminatePartial().
        StructObjectInspector soi;
        StructField maxField;
        StructField minField;
        // Read the struct fields back as primitive doubles via get().
        DoubleObjectInspector maxFieldOI;
        DoubleObjectInspector minFieldOI;

        // Output side for PARTIAL1 and PARTIAL2: reusable {max, min} row.
        Object[] partialResult;

        // Output side for FINAL and COMPLETE: reusable output text.
        Text result;

        /**
         * Configures input and output inspectors for the given evaluation mode.
         *
         * @param m          the aggregation mode
         * @param parameters raw value inspector (PARTIAL1/COMPLETE) or the partial
         *                   struct inspector (PARTIAL2/FINAL)
         * @return a struct inspector with writable-double fields "max" and "min" for
         *         PARTIAL1/PARTIAL2, otherwise a writable string inspector
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            super.init(m, parameters);

            // Input side.
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                // The input is an intermediate result: the {max, min} struct.
                soi = (StructObjectInspector) parameters[0];
                maxField = soi.getStructFieldRef("max");
                minField = soi.getStructFieldRef("min");
                maxFieldOI = (DoubleObjectInspector) maxField.getFieldObjectInspector();
                minFieldOI = (DoubleObjectInspector) minField.getFieldObjectInspector();
            }

            // Output side.
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                // Intermediate output: a struct holding both the max and the min.
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                ArrayList<String> fname = new ArrayList<String>();
                fname.add("max");
                fname.add("min");
                partialResult = new Object[2];
                partialResult[0] = new DoubleWritable(0);
                partialResult[1] = new DoubleWritable(0);
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname,
                        foi);

            } else {
                // Final output: "max\tmin" as a string.
                result = new Text("");
                return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
            }
        }

        /**
         * Per-group aggregation state holding the running max and min. (The name
         * appears to be carried over from the average UDAF this code was modelled
         * on; it is kept for compatibility.)
         *
         * <p>An empty group is represented by {@code max = -Infinity} and
         * {@code min = +Infinity}. Once any value v has been seen,
         * {@code max >= v >= min}, so {@code max < min} holds exactly when no value
         * was aggregated.
         */
        static class AverageAgg implements AggregationBuffer {
            double max;
            double min;
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AverageAgg result = new AverageAgg();
            reset(result);
            return result;
        }

        /**
         * Resets the buffer to the empty state.
         *
         * <p>Infinities are used as identities for max/min. Double.MIN_VALUE must not
         * be used for max: it is the smallest positive double, so it would beat every
         * non-positive input.
         */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            myagg.max = Double.NEGATIVE_INFINITY;
            myagg.min = Double.POSITIVE_INFINITY;
        }

        // Ensures the conversion warning is logged only once per evaluator.
        boolean warned = false;

        /**
         * Folds one raw input value into the running max and min. NULL values and
         * values that fail numeric conversion are skipped.
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters)
                throws HiveException {
            assert (parameters.length == 1);
            Object p = parameters[0];
            if (p == null) {
                return;
            }
            AverageAgg myagg = (AverageAgg) agg;
            try {
                double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
                if (v > myagg.max) {
                    myagg.max = v;
                }
                if (v < myagg.min) {
                    myagg.min = v;
                }
            } catch (NumberFormatException e) {
                if (!warned) {
                    warned = true;
                    LOG.warn(getClass().getSimpleName() + " "
                            + StringUtils.stringifyException(e));
                    LOG.warn(getClass().getSimpleName()
                            + " ignoring similar exceptions.");
                }
            }
        }

        /**
         * Packs the current max and min into the reusable partial struct. An empty
         * buffer is emitted as {-Infinity, +Infinity}, which merge() treats as a no-op.
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            ((DoubleWritable) partialResult[0]).set(myagg.max);
            ((DoubleWritable) partialResult[1]).set(myagg.min);
            return partialResult;
        }

        /**
         * Combines a partial {max, min} struct (the output of terminatePartial) into
         * this buffer.
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }
            AverageAgg myagg = (AverageAgg) agg;
            double partialMax = maxFieldOI.get(soi.getStructFieldData(partial, maxField));
            double partialMin = minFieldOI.get(soi.getStructFieldData(partial, minField));
            if (partialMax > myagg.max) {
                myagg.max = partialMax;
            }
            if (partialMin < myagg.min) {
                myagg.min = partialMin;
            }
        }

        /**
         * Produces the final result.
         *
         * @return {@code "max\tmin"}, or null if the group contained no usable value
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            // max < min can only hold for an empty group (see AverageAgg).
            if (myagg.max < myagg.min) {
                return null;
            }
            result.set(myagg.max + "\t" + myagg.min);
            return result;
        }
    }

}

 

 

 

寫完後還是覺得沒有完全理解整個過程,所以上面的註釋僅供參考,不保證一定正確!

下午加上一些輸出跟蹤一下執行過程才行,不過代碼的邏輯是沒有問題的了,本人運行過!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM