RDDs do not have this kind of function-registration mechanism. While working with Spark SQL I found UDFs genuinely useful, so I am writing a separate post to record them.
UDF => one input, one output; the equivalent of map.
UDAF => many inputs, one output; the equivalent of reduce.
UDTF => one input, many outputs; the equivalent of flatMap. (Needs a Hive environment; not tested yet.)
The short RDD sketch below illustrates this analogy.
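A minimal sketch, assuming sc is an existing SparkContext (not part of the original examples):
val rdd = sc.parallelize(Seq("7", "8", "9"))
rdd.map(s => "id" + s)           // UDF-like: one element in, one element out
rdd.map(_.toInt).reduce(_ + _)   // UDAF-like: many elements in, one value out
rdd.flatMap(s => s.split(""))    // UDTF-like: one element in, many elements out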
UDF
A UDF is really just a function you register and then call inside a SQL statement; don't overthink it. Here is an example that implements case-when logic:
import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession
object WordCount {
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder().master("local").getOrCreate()
val javasc = new JavaSparkContext(sparkSession.sparkContext)
val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}", "{'id':'8'}",
"{'id':'9'}","{'id':'10'}"));
val nameRDD1df = sparkSession.read.json(nameRDD1)
nameRDD1df.createTempView("idList")
sparkSession.udf.register("idParse", (str: String) => { // register a function that implements the case-when logic
str match{
case "7" => "id7"
case "8" => "id8"
case "9" => "id9"
case _=>"others"
}
})
sparkSession.sql("select idParse(id) from idList").show(100)
}
}
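For comparison, the same result can be produced without a UDF by writing the CASE WHEN directly in SQL (a sketch, assuming the same idList view):
sparkSession.sql(
  """select case id
    |  when '7' then 'id7'
    |  when '8' then 'id8'
    |  when '9' then 'id9'
    |  else 'others'
    |end from idList""".stripMargin).show(100)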
That covers the SQL usage of a UDF; next is the DataFrame usage.
import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{ col, udf }
object WordCount {
def myUdf(value:String): String ={
println(value)
value+"|"
}
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder().master("local").getOrCreate()
val javasc = new JavaSparkContext(sparkSession.sparkContext)
val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}", "{'id':'8'}","{'id':'9'}","{'id':'10'}"));
val fun = udf(myUdf _ )
sparkSession.read.json(nameRDD1)
.select(fun(col("id")) as "id").show(100)
}
}
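The named method is not required; the same UDF can be written as an anonymous function passed straight to udf (a sketch, reusing nameRDD1 from the example above):
val fun2 = udf((value: String) => value + "|")
sparkSession.read.json(nameRDD1)
  .select(fun2(col("id")) as "id")
  .show(100)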
UDAF
import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.{ Row, SparkSession }
import org.apache.spark.sql.types.{ DataType, IntegerType, StructField, StructType }
import org.apache.spark.sql.expressions.{ MutableAggregationBuffer, UserDefinedAggregateFunction }
class MyMax extends UserDefinedAggregateFunction{
//Data type of the input argument; either way of writing the schema works
//override def inputSchema: StructType = StructType(Array(StructField("input", IntegerType, true)))
override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
//Data type of the buffer handled during aggregation
// override def bufferSchema: StructType = StructType(Array(StructField("cache", IntegerType, true)))
override def bufferSchema: StructType = StructType(StructField("max", IntegerType) :: Nil)
//Data type of the returned result
override def dataType: DataType = IntegerType
//Whether the function is deterministic (the same input always produces the same output)
override def deterministic: Boolean = true
//Initialize each group's buffer before aggregation starts
override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}
//How to update the aggregation buffer when a new value arrives within a group
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if(input.getInt(0)> buffer.getInt(0))
buffer(0)=input.getInt(0)
}
//Merge partial buffers computed on different partitions
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
if(buffer2.getInt(0)> buffer1.getInt(0)){
buffer1(0)=buffer2.getInt(0)
}
}
//Return the final result
override def evaluate(buffer: Row): Any = {buffer.getInt(0)}
}
class MyAvg extends UserDefinedAggregateFunction{
//Data type of the input argument
override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
//Data types of the intermediate buffer: a running sum and a running count
override def bufferSchema: StructType = StructType(
StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
//Data type of the returned result
override def dataType: DataType = IntegerType
//Whether the function is deterministic
override def deterministic: Boolean = true
//Initialize the buffer: sum = 0, count = 0
override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0;buffer(1) =0;}
//Map-side aggregation: every input row passes through this method
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.getInt(0)+input.getInt(0))
buffer.update(1, buffer.getInt(1)+1)
}
//Reduce-side aggregation: merge partial buffers; unlike the input Row in update, the Row here carries both the sum and the count fields
override def merge(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.getInt(0)+input.getInt(0))
buffer.update(1, buffer.getInt(1)+input.getInt(1))
}
//Return the final result: average = sum / count (integer division here)
override def evaluate(finalValue: Row): Int = {finalValue.getInt(0)/finalValue.getInt(1)}
}
object WordCount {
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder().master("local").getOrCreate()
val javasc = new JavaSparkContext(sparkSession.sparkContext)
val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}"));
val nameRDD1df = sparkSession.read.json(nameRDD1)
val nameRDD2 = javasc.parallelize(Arrays.asList( "{'id':'8'}"));
val nameRDD2df = sparkSession.read.json(nameRDD2)
val nameRDD3 = javasc.parallelize(Arrays.asList("{'id':'9'}"));
val nameRDD3df = sparkSession.read.json(nameRDD3)
val nameRDD4 = javasc.parallelize(Arrays.asList("{'id':'10'}"));
val nameRDD4df = sparkSession.read.json(nameRDD4)
nameRDD1df.union(nameRDD2df).union(nameRDD3df).union(nameRDD4df).createOrReplaceTempView("idList") // registerTempTable is deprecated; use createOrReplaceTempView
// sparkSession.udf.register("myMax",new MyMax)
sparkSession.udf.register("myAvg",new MyAvg)
sparkSession.sql("select myAvg(id) from idList").show(100)
}
}
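The same UDAF can also be applied through the DataFrame API, since a UserDefinedAggregateFunction can be used as a column expression (a sketch, reusing the unioned DataFrame from the example above):
import org.apache.spark.sql.functions.col
val myAvg = new MyAvg
nameRDD1df.union(nameRDD2df).union(nameRDD3df).union(nameRDD4df)
  .agg(myAvg(col("id").cast("int")) as "avg_id") // id is parsed from the JSON as a string, so cast it first
  .show(100)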
UDTF (not tested yet; I do not have a Hive environment at home)
import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.SparkSession
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.ql.exec.{ UDFArgumentException, UDFArgumentLengthException }
import org.apache.hadoop.hive.serde2.objectinspector.{ ObjectInspector, ObjectInspectorFactory, StructObjectInspector }
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
class MyFloatMap extends GenericUDTF{
override def close(): Unit = {}
//This method does two things: 1. validates the input arguments 2. defines the output columns; more than one column is allowed, so a UDTF can emit multiple rows and columns
override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
if (args.length != 1) {
throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
}
if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
}
val fieldNames = new java.util.ArrayList[String]
val fieldOIs = new java.util.ArrayList[ObjectInspector]
//Default field name of the output column
fieldNames.add("col1")
//Field type of the output column
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
}
//This method does the actual work; the argument array holds a single row, i.e. each call to process handles one row
override def process(args: Array[AnyRef]): Unit = {
//Split the string into an array of single characters
val strLst = args(0).toString.split("")
for(i <- strLst){
var tmp:Array[String] = new Array[String](1)
tmp(0) = i
//forward must be given an array of strings, even when there is only one element
forward(tmp)
}
}
}
object WordCount {
def main(args: Array[String]): Unit = {
val sparkSession = SparkSession.builder().master("local").enableHiveSupport().getOrCreate() // Hive support is needed to register a Hive UDTF via create temporary function
val javasc = new JavaSparkContext(sparkSession.sparkContext)
val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}"));
val nameRDD1df = sparkSession.read.json(nameRDD1)
nameRDD1df.createOrReplaceTempView("idList")
sparkSession.sql("create temporary function myFloatMap as 'MyFloatMap'")
sparkSession.sql("select myFloatMap(id) from idList").show(100)
}
}
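Without a Hive environment, a similar one-row-to-many-rows effect can be sketched with the built-in explode and split functions (an alternative to the UDTF above, not a Hive function; it reuses nameRDD1 from the example):
import org.apache.spark.sql.functions.{ col, explode, split }
sparkSession.read.json(nameRDD1)
  .select(explode(split(col("id"), "")) as "col1") // one output row per character of id
  .show(100)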
