Spark SQL(8)-Spark sql聚合操作(Aggregation)
之前簡單總結了spark從sql到物理計划的整個流程,接下來就總結下Spark SQL中關於聚合的操作。
聚合操作的物理計划生成
首先從一條sql開始吧
SELECT NAME,COUNT(*) FRON PEOPLE GROUP BY NAME
這條sql的經過antlr4解析后的樹結構如下:

在解析出來的樹結構中可以看出來,在querySpecification下面多了aggregation子節點。這次我們只關注關於聚合的相關操作。在analyze的階段,關於聚合的解析是在AstBuilder.withQuerySpecification方法中:
private def withQuerySpecification(
ctx: QuerySpecificationContext,
relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
import ctx._
// 去掉了一些其他操作代碼。。。
// Add where.
val withFilter = withLateralView.optionalMap(where)(filter)
// Add aggregation or a project.
val namedExpressions = expressions.map {
case e: NamedExpression => e
case e: Expression => UnresolvedAlias(e)
}
val withProject = if (aggregation != null) {
withAggregation(aggregation, namedExpressions, withFilter)
} else if (namedExpressions.nonEmpty) {
Project(namedExpressions, withFilter)
} else {
withFilter
}
// Having
val withHaving = withProject.optional(having) {
// Note that we add a cast to non-predicate expressions. If the expression itself is
// already boolean, the optimizer will get rid of the unnecessary cast.
val predicate = expression(having) match {
case p: Predicate => p
case e => Cast(e, BooleanType)
}
Filter(predicate, withProject)
}
// Distinct
val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
Distinct(withHaving)
} else {
withHaving
}
// Window
// Hint
}
}
如下為withAggregation方法:
private def withAggregation(
ctx: AggregationContext,
selectExpressions: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val groupByExpressions = expressionList(ctx.groupingExpressions)
if (ctx.GROUPING != null) {
// GROUP BY .... GROUPING SETS (...)
val selectedGroupByExprs =
ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
val mappedGroupByExpressions = if (ctx.CUBE != null) {
Seq(Cube(groupByExpressions))
} else if (ctx.ROLLUP != null) {
Seq(Rollup(groupByExpressions))
} else {
groupByExpressions
}
Aggregate(mappedGroupByExpressions, selectExpressions, query)
}
}
可以看出來最后在樹中添加了一個Aggregate節點。現在這里跳過優化的操作就是物理計划的處理,物理計划里面主要關注聚合相關的策略就是:
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalAggregation(
groupingExpressions, aggregateExpressions, resultExpressions, child) =>
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
// This is a sanity check. We should not reach here when we have multiple distinct
// column sets. Our `RewriteDistinctAggregates` should take care this case.
sys.error("You hit a query analyzer bug. Please report your query to " +
"Spark user mailing list.")
}
val aggregateOperator =
if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
resultExpressions,
planLater(child))
}
aggregateOperator
case _ => Nil
}
}
從上面的邏輯可以看出來,這里根據函數里面有沒有包含distinct操作,分別調用planAggregateWithoutDistinct和planAggregateWithOneDistinct來生成物理計划。到此再經過准備階段,聚合操作的物理計划為生成也就結束了。
接下來會分析下planAggregateWithoutDistinct和planAggregateWithOneDistinct實現的不同,還有就是spark sql針對聚合操作的實現方式。在介紹這倆個之前首先介紹下聚合的模式(AggregateMode)
聚合的模式AggregateMode和聚合函數
Partial 主要是代表局部合並,對輸入的數據更新到聚合緩沖區,返回聚合緩沖區數據;
Final 將聚合緩沖區的數據進行合並,返回最終的結果;
Complete 不能進行局部合並,直接計算返回最終的結果;
PartialMerge 對聚合緩沖區的數據進行合並,其主要用於distinct語句中,返回的依然是聚合緩沖區數據。
接下來順便介紹下聚合函數分類:
1、DeclarativeAggregate 聲明式的聚合函數
2、ImperativeAggregate 指令式的聚合函數
3、TypedImperativeAggregate是ImperativeAggregate的子類,他可以用java 對象存儲在內存緩沖區中。
聲明的聚合函數和指令式的聚合函數的不同主要體現在update、merge操作上,DeclarativeAggregate對這倆個操作主要是重寫表達式的形式來體現;ImperativeAggregate則要重寫其方法。
接下來介紹下planAggregateWithoutDistinct和planAggregateWithOneDistinct的不同:
關於planAggregateWithoutDistinct:
def planAggregateWithoutDistinct(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
// Check if we can use HashAggregate.
// 1. Create an Aggregate Operator for partial aggregations.
val groupingAttributes = groupingExpressions.map(_.toAttribute)
val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
val partialAggregateAttributes =
partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialResultExpressions =
groupingAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
val partialAggregate = createAggregate(
requiredChildDistributionExpressions = None,
groupingExpressions = groupingExpressions,
aggregateExpressions = partialAggregateExpressions,
aggregateAttributes = partialAggregateAttributes,
initialInputBufferOffset = 0,
resultExpressions = partialResultExpressions,
child = child)
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
val finalAggregate = createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes,
aggregateExpressions = finalAggregateExpressions,
aggregateAttributes = finalAggregateAttributes,
initialInputBufferOffset = groupingExpressions.length,
resultExpressions = resultExpressions,
child = partialAggregate)
finalAggregate :: Nil
}
上面的方法其實可以總結成倆步,第一步就是創建一個聚合計划用於局部合並階段,第二步就是創建一個final聚合計算。
關於planAggregateWithOneDistinct:
這個其實和上面的planAggregateWithoutDistinct差不太多,只不過是變成了四步:
1、創建一個聚合計划用於局部合並階段
2、創建partialMerge計划;
3、創建一個partial計划 這一步用於distinct
4、創建一個final計划
在這倆個方法里面都用到了createAggregate在這個方法里面確定了到底使用何種方式來實現聚合計算
private def createAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
groupingExpressions: Seq[NamedExpression] = Nil,
aggregateExpressions: Seq[AggregateExpression] = Nil,
aggregateAttributes: Seq[Attribute] = Nil,
initialInputBufferOffset: Int = 0,
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {
val useHash = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
if (useHash) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
child = child)
} else {
val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
if (objectHashEnabled && useObjectHash) {
ObjectHashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
child = child)
} else {
SortAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
aggregateAttributes = aggregateAttributes,
initialInputBufferOffset = initialInputBufferOffset,
resultExpressions = resultExpressions,
child = child)
}
}
}
從上面的邏輯可以看出來;如果可以進行hashAggregate操作則選取hashAggregate; 他的具體條件是聚合的schema都在下面這些里面就可以采用hashAggregate
static {
mutableFieldTypes = Collections.unmodifiableSet(
new HashSet<>(
Arrays.asList(new DataType[] {
NullType,
BooleanType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DateType,
TimestampType
})));
}
之后如果打開了objectHash的開關,並且聚合的函數表達式是TypedImperativeAggregate那么就采用objectHash;然后如何前面倆個都不滿足那么就選擇sortAggregate聚合的方式。下面會介紹下這三種聚合方式。
HashAggregateExec介紹
hashAggregate的邏輯主要是構建一個hashmap,以分組為key,將數據保存在這個map中進行聚合計算,這個map維護在內存中,如果內存不足的情況下,會進行溢寫的操作,之后hashaggregate會退化為基於排序的聚合操作。
在doExecute方法中實例化一個TungstenAggregationIterator,在這個類里面實現了聚合的操作:
1、hashMap = new UnsafeFixedWidthAggregationMap;這個map里面保存了分組的key和其對應的聚合緩沖數據;在UnsafeFixedWidthAggregationMap里面,重要的成員變量有map = BytesToBytesMap 實際保存的數據就在這個map里面。
2、主要邏輯在processInputs方法里面:
private def processInputs(fallbackStartsAt: (Int, Int)): Unit = {
if (groupingExpressions.isEmpty) {
// If there is no grouping expressions, we can just reuse the same buffer over and over again.
// Note that it would be better to eliminate the hash map entirely in the future.
val groupingKey = groupingProjection.apply(null)
val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
while (inputIter.hasNext) {
val newInput = inputIter.next()
processRow(buffer, newInput)
}
} else {
var i = 0
while (inputIter.hasNext) {
val newInput = inputIter.next()
val groupingKey = groupingProjection.apply(newInput)
var buffer: UnsafeRow = null
if (i < fallbackStartsAt._2) {
buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
}
if (buffer == null) {
val sorter = hashMap.destructAndCreateExternalSorter()
if (externalSorter == null) {
externalSorter = sorter
} else {
externalSorter.merge(sorter)
}
i = 0
buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
if (buffer == null) {
// failed to allocate the first page
throw new SparkOutOfMemoryError("No enough memory for aggregation")
}
}
processRow(buffer, newInput)
i += 1
}
if (externalSorter != null) {
val sorter = hashMap.destructAndCreateExternalSorter()
externalSorter.merge(sorter)
hashMap.free()
switchToSortBasedAggregation()
}
}
}
這個方法里面的邏輯大體就是:從inputIter里面獲取數據,然后根據聚合緩沖區的數據對數據進行新增或者更新的操作;在調用hashMap.getAggregationBufferFromUnsafeRow(groupingKey);如果返回的數據為null,那么表示內存不足,這個時候就會進行溢寫操作:
這里的溢寫操作會new UnsafeKVExternalSorter 並返回保存到externalSorter中,如果是初次那么直接賦值,如果不是那么就進行merge;這里會把hashmap里面的map就是bytesbybtesMap的數據會傳進去,之后創建UnsafeInMemorySorter,將bytesbybtesMap的數據導入到UnsafeInMemorySorter里面;在之后調用UnsafeExternalSorter.createWithExistingInMemorySorter,對數據進行排序溢寫
溢寫結束之后會重置bytesbybtesMap然后hashmap繼續申請內存繼續計算,如果內存不足繼續溢寫;直到inputIter沒有元素;
接着會根據externalSorter是否為null來判斷需不需要切換到基於排序聚合操作。
如果不切換基於排序的聚合;則會給aggregationBufferMapIterator和mapIteratorHasNext賦值;
如果切換到基於排序的聚合;那么會調用switchToSortBasedAggregation;初始化一些基於排序的變量;之后會用於next和hasNext方法中:
基於排序的聚合需要的變量有:
externalSorter = UnsafeKVExternalSorter;在switchToSortBasedAggregation里面,externalSorter首先會調用UnsafeKVExternalSorter.sortedIterator方法拿到排序后的record迭代器,之后調用其next就行,這里的next的值就是sortedInputHasNewGroup的值,用於表示是否還有值(這里只是首次,相當於初始化這個變量的值)。
override final def hasNext: Boolean = {
(sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
}
override final def next(): UnsafeRow = {
if (hasNext) {
val res = if (sortBased) {
// Process the current group.
processCurrentSortedGroup()
// Generate output row for the current group.
val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
// Initialize buffer values for the next group.
sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
outputRow
} else {
// We did not fall back to sort-based aggregation.
val result =
generateOutput(
aggregationBufferMapIterator.getKey,
aggregationBufferMapIterator.getValue)
// Pre-load next key-value pair form aggregationBufferMapIterator to make hasNext
// idempotent.
mapIteratorHasNext = aggregationBufferMapIterator.next()
if (!mapIteratorHasNext) {
// If there is no input from aggregationBufferMapIterator, we copy current result.
val resultCopy = result.copy()
// Then, we free the map.
hashMap.free()
resultCopy
} else {
result
}
}
numOutputRows += 1
res
} else {
// no more result
throw new NoSuchElementException
}
}
下面看下next和hasnext的實現:如果是基於hash的聚合的hasnext就直接判斷mapitertor里面是否還有元素;如果有那么直接從保存的hashmap里面獲取key和value來組裝輸出以這樣的方式實現next的;
如果是基於排序的聚合next方法會查看sortedInputHasNewGroup;這個值在初始化的時候直接調用的是基於外排的kv存儲(UnsafeKVExternalSorter)的next; 之后在取值的時候,主要的邏輯就是processCurrentSortedGroup方法里面;
// Processes rows in the current group. It will stop when it find a new group.
private def processCurrentSortedGroup(): Unit = {
// First, we need to copy nextGroupingKey to currentGroupingKey.
currentGroupingKey.copyFrom(nextGroupingKey)
// Now, we will start to find all rows belonging to this group.
// We create a variable to track if we see the next group.
var findNextPartition = false
// firstRowInNextGroup is the first row of this group. We first process it.
sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup)
// The search will stop when we see the next group or there is no
// input row left in the iter.
// Pre-load the first key-value pair to make the condition of the while loop
// has no action (we do not trigger loading a new key-value pair
// when we evaluate the condition).
var hasNext = sortedKVIterator.next()
while (!findNextPartition && hasNext) {
// Get the grouping key and value (aggregation buffer).
val groupingKey = sortedKVIterator.getKey
val inputAggregationBuffer = sortedKVIterator.getValue
// Check if the current row belongs the current input row.
if (currentGroupingKey.equals(groupingKey)) {
sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer)
hasNext = sortedKVIterator.next()
} else {
// We find a new group.
findNextPartition = true
// copyFrom will fail when
nextGroupingKey.copyFrom(groupingKey)
firstRowInNextGroup.copyFrom(inputAggregationBuffer)
}
}
在這個方法里面大體的邏輯就是從sortedKVIterator的迭代器里面取數據,因為數據是基於key排序的,如果key相同那么就繼續取值聚合計算;如果不相同那么就是遇到新值了,這個時候把計算的聚合結果和key返回,當成一次next的返回;
到此這個就是大體的hash聚合的整體流程了;這里面還有一個就是基於排序的對中間數據的聚合計算其實調用的是generateProcessRow;這個方法其實就是基於當前的聚合模式和聚合的函數來決定如何計算聚合函數的值;如果是聲明式的調用updateExpressions或者mergeExpressions、如果是指令式的就調用對應的update和merge方法,這里的計算相當於只是更新聚合緩沖區的數據;
在此之后返回結果的方法:generateResultProjection,這個方法里面也會根據聚合模式和聚合的函數判斷來決定如何計算;返回的是UnsafeProjection。(這里的描述比較亂,跳躍比較大,需要跟着源碼理解,不然長篇大論更加看不懂)。
到此基於hash的聚合計算整體流程算是結束了,這里面有幾個比較重要的點;第一個就是hash的緩存是UnsafeFixedWidthAggregationMap在基於BytesToBytesMap(spark自己實現的hashmap)實現的,第二個就是hash的溢寫最后是在UnsafeExternalSorter里面溢寫UnsafeInMemorySorter里面的數據實現的,第三個就是externalSorter = UnsafeKVExternalSorter 基於排序的聚合其實是依賴UnsafeKVExternalSorter(依賴UnsafeExternalSorter)實現;第四個就是排序的中間緩存數據的計算以及最后結果輸出時的處理。
ObjectHashAggregate
這個類似於hashAggregate,主要的不同就是它是針對TypedImperativeAggregate這種類型的聚合函數來的,他主要是可以將java object緩存在內存中,參與聚合的計算;這里面的聚合緩沖區的定義是 aggBufferIterator = Iterator[AggregationBufferEntry];他的溢寫操作不同於hashAggregate-在計算中多次溢寫,它是溢寫一次就會退化到基於排序的聚合。大體的邏輯和hashAggragate的差不多。
SortAggregateExec
基於排序的聚合操作的原理就是數據根據key進行排序,之后順序讀取數據,如果key相同那么就進行聚合函數的計算,如果不同那么代表遇到了新的key;那么就重新計算新的聚合結果。
這里的實現和在hashAggregate里面的實現大同小異,主要的思想沒有變,就連計算中間的聚合函數結果的方法都是用的同一個;這里有一個需要注意的點就是:
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
這里對子節點的排序做了要求,所以在准備階段的話,在sortAggregate之前會增加排序的操作,感興趣的同學可以參考hashAggregate中對基於排序的聚合計算的描述來理解這里的基於排序的聚合計算過程。
到此整個聚合計算的過程已經總結完畢。中間還有很多可以展開的東西,但是這里只是總結聚合的操作,其他的可以在后續單獨總結。
