一、函數的源碼
/**
* Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the
* existing partitioner/parallelism level. This method is here for backward compatibility. It
* does not provide combiner classtag information to the shuffle.
*
* @see `combineByKeyWithClassTag`
*/
def combineByKey[C](
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C): RDD[(K, C)] = self.withScope {
combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners)(null)
}
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
* "combined type" C.
*
* Users provide three functions:
*
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD, the serializer that is use
* for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple
* items with the same key).
*
* @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into
* an RDD of type (Int, List[Int]).
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner,
mapSideCombine: Boolean,
serializer: Serializer): JavaPairRDD[K, C] = {
implicit val ctag: ClassTag[C] = fakeClassTag
fromRDD(rdd.combineByKeyWithClassTag(
createCombiner,
mergeValue,
mergeCombiners,
partitioner,
mapSideCombine,
serializer
))
}
由源碼中看出,該函數中主要包含的參數:
createCombiner:V=>C
mergeValue:(C,V)=>C
mergeCombiners:(C,C)=>R
partitioner:Partitioner
mapSideCombine:Boolean=true
serializer:Serializer=null
這里的每一個參數都對分別對應這聚合操作的各個階段
二、參數詳解:
1、createCombiner:V=>C 分組內的創建組合的函數。通俗點將就是對讀進來的數據進行初始化,其把當前的值作為參數,可以對該值做一些轉換操作,轉換為我們想要的數據格式
2、mergeValue:(C,V)=>C 該函數主要是分區內的合並函數,作用在每一個分區內部。其功能主要是將V合並到之前(createCombiner)的元素C上,注意,這里的C指的是上一函數轉換之后的數據格式,而這里的V指的是原始數據格式(上一函數為轉換之前的)
3、mergeCombiners:(C,C)=>R 該函數主要是進行多分取合並,此時是將兩個C合並為一個C,例如兩個C:(Int)進行相加之后得到一個R:(Int)
4、partitioner:自定義分區數,默認是hashPartitioner
5、mapSideCombine:Boolean=true 該參數是設置是否在map端進行combine操作
三、函數工作流程
首先明確該函數遍歷的數據是(k,v)對的rdd數據
1、combinByKey會遍歷rdd中每一個(k,v)數據對,對該數據對中的k進行判斷,判斷該(k,v)對中的k是否在之前出現過,如果是第一次出現,則調用createCombiner函數,對該k對應的v進行初始化操作(可以做一些轉換操作),也就是創建該k對應的累加其的初始值
2、如果這是一個在處理當前分區之前遇到的k,會調用mergeCombiners函數,把該k對應的累加器的value與這個新的value進行合並操作
四、實例解釋
實例1:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import scala.collection.mutable
object test {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("testCombineByKey").setMaster("local[2]")
val ssc = new SparkSession
.Builder()
.appName("test")
.master("local[2]")
.config(conf)
.getOrCreate()
val sc = ssc.sparkContext
sc.setLogLevel("error")
val initialScores = Array((("1", "011"), 1), (("1", "012"), 1), (("2", "011"), 1), (("2", "013"), 1), (("2", "014"), 1))
val d1 = sc.parallelize(initialScores)
d1.map(x => (x._1._1, (x._1._2, 1)))
.combineByKey(
(v: (String, Int)) => (v: (String, Int)),
(acc: (String, Int), v: (String, Int)) => (v._1+":"+acc._1,acc._2+v._2),
(p1:(String,Int),p2:(String,Int)) => (p1._1 + ":" + p2._1,p1._2 + p2._2)
).collect().foreach(println)
}
}
從map函數開始說起:
1、map端將數據格式化為:(,(String,Int))->("1",("011",1))
2、接着combineByKye函數會逐個的讀取map之后的每一個k,v數據對,當讀取到第一個("1",("011",1)),此時回判斷,“1”這個是否在之前的出現過,如果該k是第一次出現,則會調用createCombiner函數,經過轉換,該實例中是對該value值么有做任何的改變原樣返回,此時這個該value對應的key回被comgbineByKey函數創建一個累加其記錄
3、當讀取到第二個數據("1",("012",,1))的時候,回對“1”這個key進行一個判斷,發現其在之前出現過,此時怎直接調用第二個函數,mergeValues函數,對應到該實例中,acc即為上一函數產生的結果,即("1",("011",1)),v即是新讀進來的數據("1",("012",1))
4、此時執行該函數:(acc: (String, Int), v: (String, Int)) => (v._1+":"+acc._1,acc._2+v._2)
將新v中的第一個字符串與acc中的第一個字符串進行連接,v中的第二個值,與acc中的第二個值進行相加操作
5、當所有的分區內的數據計算完成之后,開始調用mergeCombiners函數,對每個分區的數據進行合並,該實例中p1和p2分別對應的是不同分區的計算結果,所以二者的數據格式是完全相同的,此時將第一個分區中的字符串相連接,第二個字符相加得到最終結果
6、最終輸出結果為:
(2,(014:013:011,3))
(1,(012:011,2))
實例2
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import scala.collection.mutable
object test {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("testCombineByKey").setMaster("local[2]")
val ssc = new SparkSession
.Builder()
.appName("test")
.master("local[2]")
.config(conf)
.getOrCreate()
val sc = ssc.sparkContext
sc.setLogLevel("error")
val initialScores = Array((("1", "011"), 1), (("1", "012"), 1), (("2", "011"), 1), (("2", "013"), 1), (("2", "014"), 1))
val d1 = sc.parallelize(initialScores)
d1.map(x => (x._1._1, (x._1._2, 1)))
//("1",("011",1))
.combineByKey(
(v: (String, Int)) => (mutable.Set[String](v._1), (mutable.Set[Int](v._2))),
(acc: (mutable.Set[String], mutable.Set[Int]), v: (String, Int)) => (acc._1 + v._1, acc._2 + v._2),
(acc1: (mutable.Set[String], mutable.Set[Int]), acc2: (mutable.Set[String], mutable.Set[Int])) => (acc1._1 ++ acc2._1, acc1._2 ++ acc2._2)
).collect().foreach(println)
}
}
運行結果為:
(2,(Set(013, 014, 011),Set(1)))
(1,(Set(011, 012),Set(1)))
