Hadoop is installed on Windows 10, and the Maven project is created with IDEA. The relevant parts of the pom.xml are:
<properties>
    <spark.version>2.2.0</spark.version>
    <scala.version>2.11</scala.version>
    <java.version>1.8</java.version>
</properties>

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-streaming_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-yarn_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <version>8.0.16</version>
    </dependency>
</dependencies>

<build>
    <finalName>learnspark</finalName>
    <plugins>
        <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <version>3.2.2</version>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>testCompile</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-assembly-plugin</artifactId>
            <version>3.0.0</version>
            <configuration>
                <archive>
                    <manifest>
                        <mainClass>learn</mainClass>
                    </manifest>
                </archive>
                <descriptorRefs>
                    <descriptorRef>jar-with-dependencies</descriptorRef>
                </descriptorRefs>
            </configuration>
            <executions>
                <execution>
                    <id>make-assembly</id>
                    <phase>package</phase>
                    <goals>
                        <goal>single</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>
Data preparation:
{"name":"張3", "age":20}
{"name":"李4", "age":20}
{"name":"王5", "age":20}
{"name":"趙6", "age":20}
Path:
data/input/user/user.json
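Note: spark.read.json infers the numeric age field as a 64-bit integer (bigint), which is why the UserBean case class in the program below declares age as BigInt. A minimal standalone sketch to confirm the inferred schema (SchemaCheck is a hypothetical helper, assuming the same local setup as the program below):

// Sketch: inspect the schema Spark infers for user.json
import org.apache.spark.sql.SparkSession

object SchemaCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("SchemaCheck").getOrCreate()
    spark.read.json("data/input/user/user.json").printSchema()
    // Expected (fields are listed alphabetically):
    // root
    //  |-- age: long (nullable = true)
    //  |-- name: string (nullable = true)
    spark.stop()
  }
}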
Program:
package com.zouxxyy.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}

/**
 * UDF: user-defined functions
 */
object UDF {

  def main(args: Array[String]): Unit = {

    // Points at a local Hadoop (winutils) directory; this can be omitted if the
    // Hadoop environment variables are already set up correctly.
    System.setProperty("hadoop.home.dir", "D:\\gitworkplace\\winutils\\hadoop-2.7.1")

    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("UDF")

    // Create the SparkSession
    val spark: SparkSession = SparkSession.builder.config(sparkConf).getOrCreate()

    import spark.implicits._

    // Reading JSON yields a DataFrame
    val frame: DataFrame = spark.read.json("data/input/user/user.json")
    frame.createOrReplaceTempView("user")

    // Case 1: register and test a simple custom function
    spark.udf.register("addName", (x: String) => "Name:" + x)
    spark.sql("select addName(name) from user").show()

    // Case 2: test a custom weakly typed aggregate function
    val udaf1 = new MyAgeAvgFunction
    spark.udf.register("avgAge", udaf1)
    spark.sql("select avgAge(age) from user").show()

    // Case 3: test a custom strongly typed aggregate function
    val udaf2 = new MyAgeAvgClassFunction
    // Convert the aggregator into a query column
    val avgCol: TypedColumn[UserBean, Double] = udaf2.toColumn.name("aveAge")
    // Use the DSL-style syntax of the strongly typed Dataset
    val userDS: Dataset[UserBean] = frame.as[UserBean]
    userDS.select(avgCol).show()

    spark.stop()
  }
}
/**
 * Custom aggregate function (weakly typed)
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Return type of the function
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic
  override def deterministic: Boolean = true

  // Initialize the buffer before aggregation
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Buffer fields are positional; they have no names, only a structure
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Update the buffer with each input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge buffers coming from different nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result from the buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
/**
 * Custom aggregate function (strongly typed)
 */
case class UserBean(name: String, age: BigInt) // numbers read from JSON are inferred as bigint
case class AvgBuffer(var sum: BigInt, var count: Int)

class MyAgeAvgClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {

  // Initialize the buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  // Combine an input record with the buffer
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // Merge two buffers
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
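As a quick sanity check, the two custom aggregate functions can be compared against Spark's built-in avg, and the addName function from case 1 can also be used in DSL style through org.apache.spark.sql.functions.udf. A minimal sketch, assuming the spark session, the frame DataFrame, and the import spark.implicits._ from the program above:

// Sketch: cross-check the custom aggregates and show a DSL-style UDF
import org.apache.spark.sql.functions.{avg, udf}

// The built-in average should match avgAge / aveAge (20.0 for the sample data)
spark.sql("select avg(age) from user").show()
frame.select(avg($"age")).show()

// DSL-style equivalent of the addName UDF registered in case 1
val addNameUdf = udf((x: String) => "Name:" + x)
frame.select(addNameUdf($"name")).show()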
