Hive User-Defined Aggregate Function (UDAF) Example (SUM + AVG Combined, with Comments)


import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;

/**
 * Combines rounding/formatting with the AVG function (running sum + count).
 * <p>
 * @author 忘塵
 * @Date 2020-06-02
 */
public class GenericUDAFAveragePlus extends AbstractGenericUDAFResolver {

	// Override AbstractGenericUDAFResolver: return the evaluator that implements this function
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
		// Validate the arguments
		if (null != info && info.length == 1) {
			// Expected case: exactly one argument

			// The argument must be a primitive type
			if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
				throw new UDFArgumentException("This function only accepts primitive-type arguments!");
			}
			// The argument must also be LONG (Hive BIGINT maps to Java long)
			// Narrow the TypeInfo to inspect the primitive category
			PrimitiveTypeInfo pti = (PrimitiveTypeInfo) info[0];
			if (!pti.getPrimitiveCategory().equals(PrimitiveObjectInspector.PrimitiveCategory.LONG)) {
				throw new UDFArgumentException("This function only accepts a BIGINT (long) argument");
			}
		} else {
			// Wrong number of arguments
			throw new UDFArgumentException("This function requires exactly one argument!");
		}

		return new MyGenericUDAFEvaluator();

	}

	// The custom evaluator that does the actual aggregation work
	private static class MyGenericUDAFEvaluator extends GenericUDAFEvaluator {
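		// Evaluator lifecycle (Hive's GenericUDAFEvaluator.Mode):
		//   PARTIAL1: iterate() -> terminatePartial()   (map side)
		//   PARTIAL2: merge()   -> terminatePartial()   (combiner)
		//   FINAL:    merge()   -> terminate()          (reduce side)
		//   COMPLETE: iterate() -> terminate()          (map-only aggregation)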
		// Custom aggregation buffer that holds the intermediate results
		private static class MyAggregationBuffer extends AbstractAggregationBuffer{
			// The buffer stores the running sum and count
			private Double sum = 0D;
			private Long count = 0L;
			public Double getSum() {
				return sum;
			}
			public void setSum(Double sum) {
				this.sum = sum;
			}
			public Long getCount() {
				return count;
			}
			public void setCount(Long count) {
				this.count = count;
			}
			
		}
		// Create a new aggregation buffer
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			printMode("getNewAggregationBuffer");
			return new MyAggregationBuffer();
		}
		// Initialization: validate parameters and decide the output ObjectInspector
		// Called once per stage (Mode)
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
			printMode("init");
			// Keep the superclass call
			super.init(m, parameters);
			// Return a different ObjectInspector depending on the stage:
			// map (PARTIAL1) and combiner (PARTIAL2) emit a partial result of sum + count as a struct
			if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
				List<String> structFieldNames = new ArrayList<String>();
				List<ObjectInspector> structFieldObjectInspectors = new ArrayList<ObjectInspector>();
				// struct<sum:double,count:bigint>
				structFieldNames.add("sum");
				structFieldNames.add("count");
				structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
				structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
				// Return the struct ObjectInspector
				return ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames, structFieldObjectInspectors);
			} else {
				// FINAL / COMPLETE stages: the final result is a formatted string
				return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
			}
			
		}
		// The AggregationBuffer carries intermediate state through the whole aggregation.
		// reset() wipes the buffer so it can be reused for the next group.
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			printMode("reset");
			((MyAggregationBuffer)agg).setCount(0L);
			((MyAggregationBuffer)agg).setSum(0D);
		}
		
		// Called once per input row on the map side (PARTIAL1) and in COMPLETE mode
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
			printMode("iterate");
			// parameters[0] is the single BIGINT argument; ignore NULL values
			if (parameters[0] == null) {
				return;
			}
			long value = Long.parseLong(String.valueOf(parameters[0]).trim());
			// Accumulate directly in the aggregation buffer so that the state is
			// reset correctly between groups (evaluator-level fields would not be)
			MyAggregationBuffer ab = (MyAggregationBuffer) agg;
			ab.setSum(ab.getSum() + value);
			ab.setCount(ab.getCount() + 1);
		}
		
		// Reusable output struct for the partial (map/combiner) result: {sum, count}
		private Object[] mapout = {new DoubleWritable(), new LongWritable()};
		
		// Emits the partial result of the map/combiner stage (PARTIAL1 / PARTIAL2)
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			printMode("terminatePartial");
			// Copy the buffer contents into the reusable output struct
			MyAggregationBuffer ab = (MyAggregationBuffer) agg;
			((DoubleWritable)mapout[0]).set(ab.getSum());
			((LongWritable)mapout[1]).set(ab.getCount());
			// Return the reusable array
			return mapout;

		}
		// Merges a partial result into the buffer; runs in the combiner (PARTIAL2) and the reducer (FINAL)
		// 'partial' is a value previously returned by terminatePartial()
		@Override
		public void merge(AggregationBuffer agg, Object partial) throws HiveException {
			printMode("merge");
			// The struct produced on the map side arrives here as a LazyBinaryStruct
			if (partial instanceof LazyBinaryStruct) {
				// Unpack the struct fields
				LazyBinaryStruct lbs = (LazyBinaryStruct) partial;
				
				DoubleWritable sum = (DoubleWritable) lbs.getField(0);
				LongWritable count = (LongWritable) lbs.getField(1);
				
				// Add this partial sum and count to the buffer
				MyAggregationBuffer ab = (MyAggregationBuffer) agg;
				ab.setCount(ab.getCount() + count.get());
				ab.setSum(ab.getSum() + sum.get());
			}

		}
		// Reusable output object for terminate()
		private Text reduceout = new Text();
		// Produces the final result in the reducer (FINAL) or in COMPLETE mode
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			printMode("terminate");
			// Compute the average from the accumulated sum and count
			MyAggregationBuffer ab = (MyAggregationBuffer) agg;
			Double sum = ab.getSum();
			Long count = ab.getCount();
			// Guard against an empty group (e.g. all inputs were NULL)
			if (count == 0) {
				return null;
			}
			Double avg = sum / count;
			// Format with grouping separators and two decimal places, e.g. 1,234.56
			DecimalFormat df = new DecimalFormat("###,###.00");
			reduceout.set(df.format(avg));
			return reduceout;
		}
		// Print which evaluator method is running (for tracing the execution phases)
        public void printMode(String mname){
            System.out.println("=================================== "+mname+" is Running! ================================");
        }

	}

}
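
After packaging the class into a jar, the UDAF can be registered and called from Hive roughly as follows. This is a minimal usage sketch: the jar path, function name, table and column names are placeholders and not part of the original code; if the class is compiled inside a package, use its fully qualified name.

ADD JAR /path/to/udaf-average-plus.jar;

CREATE TEMPORARY FUNCTION avg_plus AS 'GenericUDAFAveragePlus';

-- The argument must be BIGINT, as enforced in getEvaluator();
-- the result is a formatted string such as 1,234.56
SELECT dept, avg_plus(salary) FROM employee GROUP BY dept;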

