hive自定義分段函數(分箱)


分段函數常用於分箱中,統計分組在指定的區間中的占比。

比如有如下例子:統計某個班級中考試分數在各個階段的占比。

准備的數據如下:

使用如下文件在hive中建表。

class1,1,100
class1,2,88
class1,3,90
class1,4,23
class1,5,30
class1,6,55
class1,7,66
class1,8,99
class1,9,56
class1,10,34

 

這時候使用case when來計算每行記錄分別在哪個區間如下:

with tmp_a as(
select 
clazz,name,
case when score <30 then '[0,30)'
when score <60 then '[30,60)'
when score < 80 then '[60,80)'
when score <= 100 then '[80,100]'
else 'none' end bins
from dt_dwd.score
)
select clazz,bins,count(1)/sum(count(1)) over (partition by clazz) as rate,count(1)
from tmp_a group by clazz,bins; 

最后是統計結果如下:

 

 

上述就是通常的分箱占比操作例子。


 

 

現在我有多組標簽需要監控,每次寫case when的,這里面的分段非常多,於是想到用hive udf來簡化寫法。

先看已經完成的自定義函數default.piecewise的sql寫法如下:

select clazz,name,default.piecewise('[0,30)|[30,60)|[60,80)|[80,100]',score) as bins
from dt_dwd.score

  

完整的sql如下:

with tmp_a as(
	select 
	clazz,name,default.piecewise('[0,30)|[30,60)|[60,80)|[80,100]',score) as bins
	from dt_dwd.score
)
select clazz,bins,count(1)/sum(count(1)) over (partition by clazz) as rate,count(1)
from tmp_a group by clazz,bins;

這樣我們可以將分箱抽象到變量中,在當做參數傳入,就不要每次寫很大段的case when了。

 

 default.piecewise的完整寫法如下:

package com.demo.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaConstantStringObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

import java.util.Objects;

/**
 * @Author: KingWang
 * @Date: 2021/9/21
 * @Desc: 自定義分段函數
 *   傳入參數1:none|[0,30]|(30,60]|(60,90]|(90,+]
 *   傳入參數2:值
 *   返回:根據參數2的值,判斷在參數1的區間,返回參數1的區間值
 *   如: 參數2:45, 則返回(30,60]
 **/
@Description(name = "default.piecewise", value = "_FUNC_(piecewise, value) - Returns piecewise if the value mapped.", extended = "Example:\n  > SELECT _FUNC_('[0,30)|[30,60)|(60,100]', 88) FROM table limit 1;\n  '(60,100]'")
public class Piecewise extends GenericUDF {

    private transient StringObjectInspector piecewiseOI;
    private transient StringObjectInspector valOI;



    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length != 2) {
            throw new UDFArgumentException("The function piecewiseUDF accepts 2 arguments.");
        }
        if(null == arguments[0]){
            throw new UDFArgumentException("first arguments can not be null.");
        }

        this.piecewiseOI = (StringObjectInspector) arguments[0];
        this.valOI = (StringObjectInspector) (null == arguments[1] ? new JavaConstantStringObjectInspector("")  : arguments[1]);
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    @Override
    public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {

        String piecewise = piecewiseOI.getPrimitiveJavaObject(deferredObjects[0].get());
        String val = valOI.getPrimitiveJavaObject(null != deferredObjects[1] ? deferredObjects[1].get():"");

        String[] list = piecewise.split("\\|");
        if(Objects.isNull(val)||"".equals(val)){
            return new Text("none");
        }
        boolean match = false;
        for(String str:list){
            if(str.indexOf(",")>0){
                try{
                    double value = Double.valueOf(val);
                    if(str.startsWith("(-,") || str.startsWith("[-,")){
                        match = minusAndValue(str,value);
                    }else if(str.endsWith(",+)") || str.endsWith(",+]")){
                        match = valueAndPlus(str,value);
                    }else{
                        match = valueAndValue(str,value);
                    }
                }catch (NumberFormatException e){

                }
            } else {
                if(str.equalsIgnoreCase(val)){
                    match = true;
                }
            }
            if(match) return new Text(str);
        }

        //未匹配上的返回到ERROR分組中
        return new Text("ERROR");
    }

    @Override
    public String getDisplayString(String[] strings) {
        return strings[0];
    }



    /**
     * 表達式類似於(-,60)或者[-,60)或者(-,60]或者[-,60]
     * @param express
     * @return
     */
    public static boolean minusAndValue(String express,Double value){
        boolean is_match = false;
        String endStr = express.split(",")[1];
        double end = Double.valueOf(endStr.substring(0,endStr.length()-1));
        if(express.endsWith(")")){
            if( value < end ){
                is_match = true;
            }
        }else if(express.endsWith("]")){
            if( value <= end){
                is_match = true;
            }
        }
        return is_match;
    }

    /**
     * 表達式類似於(80,+)或者[80,+)或者(80,+]或者[80,+]
     * @param express
     * @return
     */
    public static boolean valueAndPlus(String express,Double value){
        boolean is_match = false;
        String beginStr = express.split(",")[0];
        double begin = Double.valueOf(beginStr.substring(1));
        if(express.startsWith("(")){
            if( value > begin ){
                is_match = true;
            }
        }else if(express.startsWith("[")){
            if(value >= begin){
                is_match = true;
            }
        }
        return is_match;
    }

    /**
     * 表達式類似於(60,80)或者(60,80]或者[60,80)或者[60,80]
     * @param express
     * @return
     */
    public static boolean valueAndValue(String express,Double value){
        boolean is_match = false;
        String beginStr = express.split(",")[0];
        String endStr = express.split(",")[1];

        double begin = Double.valueOf(beginStr.substring(1));
        double end = Double.valueOf(endStr.substring(0,endStr.length()-1));
        if(express.startsWith("(") && express.endsWith(")")){
            if( value> begin && value < end){
                is_match = true;
            }
        }else if(express.startsWith("[") && express.endsWith("]")){
            if(value >= begin && value <= end){
                is_match = true;
            }
        }else if(express.startsWith("(") && express.endsWith("]")){
            if(value > begin && value <= end){
                is_match = true;
            }
        }else if(express.startsWith("[") && express.endsWith(")")){
            if(value >= begin && value < end){
                is_match = true;
            }
        }
        return is_match;
    }
}

 

 

實際使用中,shell腳本中定義分箱變量,通過參數傳遞給scrip.sql腳本

使用注意事項: 

當傳入的值為null時,會報異常,需要使用nvl(nullfileld,'') 或者nvl(nulllfiled,'none')來處理,其結果將默認分配在none分段中。 

 

 然后在script.sql腳本中可以接收傳入的分箱變量來靈活使用。


2021.09.26 優化:

將未匹配的行,使用ERROR分箱表示,因為執行程序過程中如果發現有不在給定的分箱里面的,會報錯,到時候排錯特別困難。

這樣直接給個帶ERROR的分箱組,可以很直接的在結果數據中可以觀察到。

  

  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM