我們都知道Spark內部提供了HashPartitioner
和RangePartitioner
兩種分區策略,這兩種分區策略在很多情況下都適合我們的場景。但是有些情況下,Spark內部不能符合咱們的需求,這時候我們就可以自定義分區策略。為此,Spark提供了相應的接口,我們只需要擴展Partitioner
抽象類,然后實現里面的三個方法:
01 package org.apache.spark 02 03 /** 04 * An object that defines how the elements in a key-value pair RDD are partitioned by key. 05 * Maps each key to a partition ID, from 0 to `numPartitions - 1`. 06 */ 07 abstract class Partitioner extends Serializable { 08 def numPartitions: Int 09 def getPartition(key: Any): Int 10 }
def numPartitions: Int
:這個方法需要返回你想要創建分區的個數;
def getPartition(key: Any): Int
:這個函數需要對輸入的key做計算,然后返回該key的分區ID,范圍一定是0到numPartitions-1
;
equals()
:這個是Java標准的判斷相等的函數,之所以要求用戶實現這個函數是因為Spark內部會比較兩個RDD的分區是否一樣。
假如我們想把來自同一個域名的URL放到一台節點上,比如:http://www.iteblog.com
和http://www.iteblog.com/archives/1368
,如果你使用HashPartitioner
,這兩個URL的Hash值可能不一樣,這就使得這兩個URL被放到不同的節點上。所以這種情況下我們就需要自定義我們的分區策略,可以如下實現:
01 package com.iteblog.utils 02 03 import org.apache.spark.Partitioner 04 05 /** 06 * User: 過往記憶 07 * Date: 2015-05-21 08 * Time: 下午23:34 09 * bolg: http://www.iteblog.com 10 * 本文地址:http://www.iteblog.com/archives/1368 11 * 過往記憶博客,專注於hadoop、hive、spark、shark、flume的技術博客,大量的干貨 12 * 過往記憶博客微信公共帳號:iteblog_hadoop 13 */ 14 15 class IteblogPartitioner(numParts: Int) extends Partitioner { 16 override def numPartitions: Int = numParts 17 18 override def getPartition(key: Any): Int = { 19 val domain = new java.net.URL(key.toString).getHost() 20 val code = (domain.hashCode % numPartitions) 21 if (code < 0) { 22 code + numPartitions 23 } else { 24 code 25 } 26 } 27 28 override def equals(other: Any): Boolean = other match { 29 case iteblog: IteblogPartitioner => 30 iteblog.numPartitions == numPartitions 31 case _ => 32 false 33 } 34 35 override def hashCode: Int = numPartitions 36 }
因為hashCode值可能為負數,所以我們需要對他進行處理。然后我們就可以在partitionBy()方法里面
使用我們的分區:
1 iteblog.partitionBy(new IteblogPartitioner(20))
類似的,在Java中定義自己的分區策略和Scala類似,只需要繼承org.apache.spark.Partitioner,並實現其中的方法即可。
在Python中,你不需要擴展Partitioner類,我們只需要對iteblog.partitionBy()加上一個額外的hash函數,如下:
1 import urlparse 2 3 def iteblog_domain(url): 4 return hash(urlparse.urlparse(url).netloc) 5 6 iteblog.partitionBy(20, iteblog_domain)