卷首語
前一篇文章hive UDAF開發入門和運行過程詳解(轉)里面講過UDAF的開發過程,其中說到如果要深入理解UDAF的執行,可以看看求平均值的UDF的源碼
本人在看完源碼后,也還是沒能十分理解里面的內容,於是動手再自己開發一個新的函數,試圖多實踐中理解它
函數功能介紹
函數的功能比較蛋疼,我們都知道Hive中有幾個常用的聚合函數:sum,max,min,avg
現在要用一個函數來同時實現兩個不同的功能,對於同一個key,要求返回指定value集合中的最大值與最小值
這里面涉及到一個難點,函數接收到的數據只有一個,但是要同時產生出兩個新的數據出來,且具備一定的邏輯關系
語言描述這東西我不大懂,想了好久,還是直接上代碼得了。。。。。。。。。。。。。
源碼
package org.juefan.udaf;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;

/**
 * GenericUDAFMaxMin: a Hive UDAF that computes the maximum and the minimum of
 * a numeric (or numeric-string) column in a single aggregation pass, returning
 * both as one tab-separated string {@code "max\tmin"}, or NULL for a group
 * with no non-NULL rows.
 */
@Description(name = "maxmin", value = "_FUNC_(x) - Returns the max and min value of a set of numbers")
public class GenericUDAFMaxMin extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFMaxMin.class.getName());

  /**
   * Validates the argument list and returns the evaluator.
   *
   * @param parameters type info for each call-site argument; exactly one
   *     primitive numeric/string/timestamp argument is accepted
   * @throws SemanticException (as {@link UDFArgumentTypeException}) if the
   *     argument count or type is unsupported
   */
  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }
    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
    case STRING:
    case TIMESTAMP:
      return new GenericUDAFMaxMinEvaluator();
    case BOOLEAN:
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
  }

  /**
   * GenericUDAFMaxMinEvaluator: performs the actual aggregation.
   *
   * <p>Partial results travel between map and reduce sides as a struct
   * {@code {max:double, min:double, count:bigint}}; {@code count} records how
   * many rows were aggregated so an empty group can be told apart from a group
   * whose extremes happen to be 0.
   */
  public static class GenericUDAFMaxMinEvaluator extends GenericUDAFEvaluator {

    // For PARTIAL1 and COMPLETE: inspector for the raw input column.
    PrimitiveObjectInspector inputOI;

    // For PARTIAL2 and FINAL: inspector for the partial-result struct and
    // handles to its fields.
    StructObjectInspector soi;
    StructField maxField;
    StructField minField;
    StructField countField;
    DoubleObjectInspector maxFieldOI;
    DoubleObjectInspector minFieldOI;
    LongObjectInspector countFieldOI;

    // For PARTIAL1 and PARTIAL2: reusable struct emitted by terminatePartial().
    Object[] partialResult;

    // For FINAL and COMPLETE: reusable output value.
    Text result;

    /**
     * Wires up input/output object inspectors for the given execution mode.
     * PARTIAL1/COMPLETE read the raw column; PARTIAL2/FINAL read the struct
     * produced by {@link #terminatePartial}. PARTIAL1/PARTIAL2 emit that
     * struct; FINAL/COMPLETE emit the tab-joined string.
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      assert (parameters.length == 1);
      super.init(m, parameters);

      // Initialize the input side.
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        inputOI = (PrimitiveObjectInspector) parameters[0];
      } else {
        // Input is an intermediate result: resolve the struct fields.
        soi = (StructObjectInspector) parameters[0];
        maxField = soi.getStructFieldRef("max");
        minField = soi.getStructFieldRef("min");
        countField = soi.getStructFieldRef("count");
        maxFieldOI = (DoubleObjectInspector) maxField.getFieldObjectInspector();
        minFieldOI = (DoubleObjectInspector) minField.getFieldObjectInspector();
        countFieldOI = (LongObjectInspector) countField.getFieldObjectInspector();
      }

      // Initialize the output side.
      if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
        // Emit a struct holding the running max, min, and row count.
        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
        ArrayList<String> fname = new ArrayList<String>();
        fname.add("max");
        fname.add("min");
        fname.add("count");
        partialResult = new Object[3];
        partialResult[0] = new DoubleWritable(0);
        partialResult[1] = new DoubleWritable(0);
        partialResult[2] = new LongWritable(0);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
      } else {
        // Last stage: emit the final string value.
        result = new Text("");
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
      }
    }

    /** Per-group aggregation state: running extremes plus rows-seen count. */
    static class MaxMinAgg implements AggregationBuffer {
      double max;
      double min;
      long count; // rows successfully aggregated; 0 means "no data"
    };

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      MaxMinAgg result = new MaxMinAgg();
      reset(result);
      return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      MaxMinAgg myagg = (MaxMinAgg) agg;
      // BUG FIX: was Double.MIN_VALUE, which is the smallest *positive*
      // double (~4.9e-324) — the max of an all-negative group came out wrong.
      // Infinities compare correctly against every finite input.
      myagg.max = Double.NEGATIVE_INFINITY;
      myagg.min = Double.POSITIVE_INFINITY;
      myagg.count = 0;
    }

    boolean warned = false;

    /**
     * Folds one raw row into the buffer; NULLs are skipped, unparseable
     * values are logged once and ignored (matching the built-in numeric
     * UDAFs' lenient behavior).
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      assert (parameters.length == 1);
      Object p = parameters[0];
      if (p != null) {
        MaxMinAgg myagg = (MaxMinAgg) agg;
        try {
          double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
          if (myagg.max < v) {
            myagg.max = v;
          }
          if (myagg.min > v) {
            myagg.min = v;
          }
          myagg.count++;
        } catch (NumberFormatException e) {
          if (!warned) {
            warned = true;
            LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
            LOG.warn(getClass().getSimpleName() + " ignoring similar exceptions.");
          }
        }
      }
    }

    /** Packages the running state as the {max, min, count} struct. */
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      MaxMinAgg myagg = (MaxMinAgg) agg;
      ((DoubleWritable) partialResult[0]).set(myagg.max);
      ((DoubleWritable) partialResult[1]).set(myagg.min);
      ((LongWritable) partialResult[2]).set(myagg.count);
      return partialResult;
    }

    /** Folds a partial struct (output of terminatePartial) into the buffer. */
    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        MaxMinAgg myagg = (MaxMinAgg) agg;
        Object partialmax = soi.getStructFieldData(partial, maxField);
        Object partialmin = soi.getStructFieldData(partial, minField);
        Object partialcount = soi.getStructFieldData(partial, countField);
        if (myagg.max < maxFieldOI.get(partialmax)) {
          myagg.max = maxFieldOI.get(partialmax);
        }
        if (myagg.min > minFieldOI.get(partialmin)) {
          myagg.min = minFieldOI.get(partialmin);
        }
        myagg.count += countFieldOI.get(partialcount);
      }
    }

    /**
     * Produces the final "max\tmin" string, or NULL for an empty group.
     */
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      MaxMinAgg myagg = (MaxMinAgg) agg;
      // BUG FIX: was (myagg.max == 0), which returned NULL for any group
      // whose maximum is legitimately 0 and never fired for a truly empty
      // group. The row count distinguishes the two cases unambiguously.
      if (myagg.count == 0) {
        return null;
      }
      result.set(myagg.max + "\t" + myagg.min);
      return result;
    }
  }
}
寫完后還是覺得沒有怎么理解透整個過程,所以上面的注釋也就將就着看了,不保證一定正確的!
下午加上一些輸出跟蹤一下執行過程才行,不過代碼的邏輯是沒有問題的了,本人運行過!