Spark: Column-to-Row and Row-to-Column Transformations with a Custom Aggregate Function


1. Analysis

  Spark ships with a very rich set of operators that cover most processing logic. For example, a row-to-column transformation can be written as concat_ws(',', collect_set('field')), which hiveContext supports. This has an obvious limitation, however: sqlContext does not support it. Being able to implement the same logic with ordinary code or with a custom aggregate function is therefore well worth knowing.
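
  For reference, here is a minimal sketch of that hiveContext approach. It assumes a Hive-enabled context and a registered table "term" with columns arth_type and content, matching the data used later in this post:

// Row-to-column with Hive's built-in functions: collect_set gathers the grouped
// values into an array, and concat_ws joins them with a separator.
// Sketch only - assumes a table "term" with columns arth_type and content.
hiveContext.sql(
  "SELECT arth_type, concat_ws(',', collect_set(content)) AS content " +
  "FROM term GROUP BY arth_type"
).show()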

2. Column-to-Row Implementation

package utils

import com.hankcs.hanlp.tokenizer.StandardTokenizer
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
  * Created by Administrator on 2019/12/17.
  */
object Column2Row {
  /**
    * Set the log level
    */
  Logger.getLogger("org").setLevel(Level.WARN)

  def main(args: Array[String]) {
    val spark = SparkSession.builder().master("local[2]").appName(s"${this.getClass.getSimpleName}").getOrCreate()
    val sc = spark.sparkContext
    val sqlContext = spark.sqlContext

    val array: Array[String] = Array("spark-高性能大數據解決方案", "spark-機器學習圖計算", "solr-搜索引擎應用廣泛", "solr-ES靈活高效")
    val rdd = sc.parallelize(array)

    val termRdd = rdd.map(row => { // standard segmentation with the HanLP tokenizer
      var result = ""
      val type_content = row.split("-")
      val termList = StandardTokenizer.segment(type_content(1))
      for (i <- 0 until termList.size()) {
        val term = termList.get(i)
        // skip punctuation (w), auxiliary words (u) and numerals (m)
        if (!term.nature.name().contains("w") && !term.nature.name().contains("u") && !term.nature.name().contains("m")) {
          if (term.word.length > 1) {
            result += term.word + " "
          }
        }
      }
      Row(type_content(0), result)
    })

    val structType = StructType(Array(
      StructField("arth_type", StringType, true),
      StructField("content", StringType, true)
    ))

    val termDF = sqlContext.createDataFrame(termRdd, structType)
    termDF.show(false)

    /**
      * Column to row: split the space-separated content field into one row per word
      */
    val termCheckDF = termDF.rdd.flatMap(row => {
      val arth_type = row.getAs[String]("arth_type")
      val content = row.getAs[String]("content")
      var res = Seq[Row]()
      val content_array = content.split(" ")
      for (con <- content_array) {
        res = res :+ Row(arth_type, con)
      }
      res
    }).collect()

    val termListDF = sqlContext.createDataFrame(sc.parallelize(termCheckDF), structType)
    termListDF.show(false)

    sc.stop()
  }
}
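
  On Spark 2.x the same column-to-row step can also be expressed directly on the DataFrame with the built-in split and explode functions. This is only an alternative sketch, not part of the original example; it assumes the termDF built above and avoids collecting the data back to the driver:

import org.apache.spark.sql.functions.{explode, split}

// Alternative sketch: explode the space-separated content column into one row per word.
val explodedDF = termDF
  .withColumn("content", explode(split(termDF("content"), " ")))
  .filter("content != ''") // the trailing space in result yields an empty token
explodedDF.show(false)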

3. Column-to-Row Results

  Before the column-to-row transformation: [screenshot of the termDF.show(false) output]

  After the column-to-row transformation: [screenshot of the termListDF.show(false) output]

4. Row-to-Column Implementation

package test

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

/**
  * Custom aggregate function (UDAF) implementing the row-to-column transformation
  */
object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {
  // Input schema of the aggregate function
  override def inputSchema: StructType = StructType(StructField("input", StringType) :: Nil)

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = StructType(StructField("result", StringType) :: Nil)

  // Data type of the result
  override def dataType: DataType = StringType

  // Whether the same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  // Process one input row and append its value to the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (input.isNullAt(0)) return
    if (buffer.getString(0) == null || buffer.getString(0).isEmpty) {
      buffer(0) = input.getString(0) // first value, no separator needed
    } else {
      buffer(0) = buffer.getString(0) + "," + input.getString(0) // append with a comma
    }
  }

  // Merge two partial aggregation buffers
  override def merge(bufferLeft: MutableAggregationBuffer, bufferRight: Row): Unit = {
    if (bufferRight.getString(0) == null || bufferRight.getString(0).isEmpty) return
    if (bufferLeft.getString(0) == null || bufferLeft.getString(0).isEmpty) {
      bufferLeft(0) = bufferRight.getString(0)
    } else {
      bufferLeft(0) = bufferLeft.getString(0) + "," + bufferRight.getString(0) // concatenate the two buffers
    }
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = buffer.getString(0)
}

/**
  * Created by Administrator on 2019/12/17.
  */
object Row2Columns {
  /**
    * Set the log level
    */
  Logger.getLogger("org").setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName(s"${this.getClass.getSimpleName}").getOrCreate()
    val sc = spark.sparkContext
    val sqlContext = spark.sqlContext

    val array: Array[String] = Array("大數據-Spark", "大數據-Hadoop", "大數據-Flink", "搜索引擎-Solr", "搜索引擎-ES")

    val termRdd = sc.parallelize(array).map(row => { // split each record into (category, content)
      val content = row.split("-")
      Row(content(0), content(1))
    })

    val structType = StructType(Array(
      StructField("arth_type", StringType, true),
      StructField("content", StringType, true)
    ))

    val termDF = sqlContext.createDataFrame(termRdd, structType)
    termDF.show()
    termDF.createOrReplaceTempView("term")

    /**
      * Register the UDAF (here under the name "concat_ws", the same name as Spark's built-in function)
      */
    spark.udf.register("concat_ws", AverageUserDefinedAggregateFunction)
    spark.sql("select arth_type,concat_ws(content) content from term group by arth_type").show()
  }
}
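
  Besides the SQL path shown above, the same UDAF object can be applied through the DataFrame API. A minimal sketch, assuming the termDF built in the example:

import org.apache.spark.sql.functions.col

// Sketch only: invoke the UDAF object directly on a grouped DataFrame.
termDF.groupBy("arth_type")
  .agg(AverageUserDefinedAggregateFunction(col("content")).as("content"))
  .show()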

5. Row-to-Column Results

  Before the row-to-column transformation: [screenshot of the termDF.show() output]

  After the row-to-column transformation: [screenshot of the grouped concat_ws query output]

