As we all know, Spark ships with two built-in partitioning strategies, HashPartitioner and RangePartitioner (a code walkthrough of both is available in 《Spark分区器HashPartitioner和RangePartitioner代码详解》). They cover most scenarios, but sometimes the built-in partitioners do not fit our needs, and in that case we can define our own partitioning strategy. Spark exposes an interface for exactly this purpose: we only need to extend the Partitioner abstract class and implement its three methods:
package org.apache.spark

abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}
- def numPartitions: Int: this method returns the number of partitions you want to create;
- def getPartition(key: Any): Int: this function computes the partition ID for a given key; the result must lie in the range 0 to numPartitions - 1;
- equals(): the standard Java equality method. Spark asks you to implement it because it internally needs to compare whether two RDDs are partitioned in the same way.
Suppose we want URLs from the same domain to land on the same node, for example http://www.iteblog.com and http://www.iteblog.com/archives/1368. With HashPartitioner these two URLs will most likely hash to different values and therefore end up on different nodes. In a case like this we need a custom partitioning strategy, which can be implemented as follows:
package com.iteblog.utils

import org.apache.spark.Partitioner

class IteblogPartitioner(numParts: Int) extends Partitioner {
  override def numPartitions: Int = numParts

  override def getPartition(key: Any): Int = {
    val domain = new java.net.URL(key.toString).getHost()
    val code = domain.hashCode % numPartitions
    if (code < 0) {
      code + numPartitions  // hashCode may be negative, so shift it into [0, numPartitions)
    } else {
      code
    }
  }

  // Standard Java equals, so Spark can tell whether two RDDs share this partitioner
  override def equals(other: Any): Boolean = other match {
    case iteblog: IteblogPartitioner =>
      iteblog.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
Because hashCode can be negative, we need to handle that case, as the check in getPartition() above does. We can then use our partitioner inside partitionBy():
iteblog.partitionBy(new IteblogPartitioner(20))
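For context, a fuller (hypothetical) usage might look like this; the RDD name links and its contents are made up, and Spark 1.3+ is assumed so the pair-RDD implicits are in scope. Persisting right after partitionBy() keeps the shuffled layout from being recomputed by later actions:

import org.apache.spark.{SparkConf, SparkContext}
import com.iteblog.utils.IteblogPartitioner

object IteblogPartitionerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("IteblogPartitionerDemo"))

    // A pair RDD keyed by URL (contents are hypothetical)
    val links = sc.parallelize(Seq(
      ("http://www.iteblog.com", 1),
      ("http://www.iteblog.com/archives/1368", 1)
    ))

    // URLs from the same domain now land in the same partition;
    // persist so later actions reuse the shuffled data
    val partitioned = links.partitionBy(new IteblogPartitioner(20)).persist()

    println(partitioned.partitioner) // prints Some(com.iteblog.utils.IteblogPartitioner@...)
    sc.stop()
  }
}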
Defining your own partitioning strategy in Java works much the same way: extend org.apache.spark.Partitioner and implement the methods above.
In Python you do not extend a Partitioner class; instead you pass an extra hash function to iteblog.partitionBy(), for example:
import urlparse

def iteblog_domain(url):
    return hash(urlparse.urlparse(url).netloc)

iteblog.partitionBy(20, iteblog_domain)
The section above is reproduced from http://www.iteblog.com/, and the section below is reproduced from http://blog.csdn.net/bluejoe2000/article/details/41415087.

RDD is an abstract class that defines methods such as map() and reduce(), but in practice a class that extends RDD generally only needs to implement two methods:
- def getPartitions: Array[Partition]
- def compute(thePart: Partition, context: TaskContext): Iterator[T]
getPartitions() tells Spark how to slice the input into partitions;
compute() emits every row of a given Partition ("row" is my loose term here; strictly it is one unit handed to the processing function);
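To make the two methods concrete, here is a minimal sketch of a custom RDD. The names SimpleRDD and SimplePartition, and the idea of serving a fixed in-memory Seq, are invented purely for illustration; a real RDD would normally wrap an external data source.

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// A hypothetical partition that remembers its index and its slice of the data
class SimplePartition(val index: Int, val rows: Seq[String]) extends Partition

// A toy RDD over a fixed in-memory Seq, split into numSlices partitions
class SimpleRDD(sc: SparkContext, data: Seq[String], numSlices: Int)
  extends RDD[String](sc, Nil) {

  // getPartitions: describe how the input is sliced
  override def getPartitions: Array[Partition] =
    (0 until numSlices).map { i =>
      val start = i * data.size / numSlices
      val end = (i + 1) * data.size / numSlices
      new SimplePartition(i, data.slice(start, end)): Partition
    }.toArray

  // compute: emit every row of one partition
  override def compute(thePart: Partition, context: TaskContext): Iterator[String] =
    thePart.asInstanceOf[SimplePartition].rows.iterator
}

An action on such an RDD calls getPartitions once to plan the tasks, and compute once per partition on the executors.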
Take HadoopRDD, the RDD over an HDFS file, as an example:
override def getPartitions: Array[Partition] = {
  val jobConf = getJobConf()

  SparkHadoopUtil.get.addCredentials(jobConf)
  val inputFormat = getInputFormat(jobConf)
  if (inputFormat.isInstanceOf[Configurable]) {
    inputFormat.asInstanceOf[Configurable].setConf(jobConf)
  }
  val inputSplits = inputFormat.getSplits(jobConf, minPartitions)
  val array = new Array[Partition](inputSplits.size)
  for (i <- 0 until inputSplits.size) {
    array(i) = new HadoopPartition(id, i, inputSplits(i))
  }
  array
}
It simply wraps each input split into a partition of the RDD. Now look at compute():
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
  val iter = new NextIterator[(K, V)] {

    val split = theSplit.asInstanceOf[HadoopPartition]
    logInfo("Input split: " + split.inputSplit)
    var reader: RecordReader[K, V] = null
    val jobConf = getJobConf()
    val inputFormat = getInputFormat(jobConf)
    HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
      context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
    reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)

    context.addTaskCompletionListener{ context => closeIfNeeded() }
    val key: K = reader.createKey()
    val value: V = reader.createValue()

    val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
    try {
      inputMetrics.bytesRead = split.inputSplit.value.getLength()
    } catch {
      case e: java.io.IOException =>
        logWarning("Unable to get input size to set InputMetrics for task", e)
    }
    context.taskMetrics.inputMetrics = Some(inputMetrics)

    override def getNext() = {
      try {
        finished = !reader.next(key, value)
      } catch {
        case eof: EOFException =>
          finished = true
      }
      (key, value)
    }

    override def close() {
      try {
        reader.close()
      } catch {
        case e: Exception => logWarning("Exception in RecordReader.close()", e)
      }
    }
  }
  new InterruptibleIterator[(K, V)](context, iter)
}
It uses the reader to return a sequence of (K, V) key-value pairs. Now let us look at JdbcRDD, which reads from a database:
override def getPartitions: Array[Partition] = {

  val length = 1 + upperBound - lowerBound
  (0 until numPartitions).map(i => {
    val start = lowerBound + ((i * length) / numPartitions).toLong
    val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
    new JdbcPartition(i, start, end)
  }).toArray
}
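To trace the arithmetic with hypothetical numbers (lowerBound = 1, upperBound = 100, numPartitions = 3), the length is 100 and the loop yields the inclusive ranges below:

// Hypothetical bounds, only to trace the slicing logic above
val lowerBound = 1L
val upperBound = 100L
val numPartitions = 3
val length = 1 + upperBound - lowerBound                        // 100
val ranges = (0 until numPartitions).map { i =>
  val start = lowerBound + (i * length) / numPartitions         // 1, 34, 67
  val end = lowerBound + ((i + 1) * length) / numPartitions - 1 // 33, 66, 100
  (start, end)
}
// ranges == Vector((1,33), (34,66), (67,100))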
So getPartitions() simply splits the [lowerBound, upperBound] key range into numPartitions slices. Many of its parameters come from the constructor:
class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
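Before moving on to compute(), here is a hedged usage sketch; the JDBC URL, table and column names are invented, but the call shape follows the constructor above, and the SQL must contain two '?' placeholders, which compute() binds to each partition's lower and upper bounds:

import java.sql.DriverManager
import org.apache.spark.rdd.JdbcRDD

// Assumes an existing SparkContext `sc` and a hypothetical MySQL table urls(id BIGINT, url VARCHAR)
val rdd = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:mysql://localhost:3306/iteblog", "user", "passwd"),
  "SELECT id, url FROM urls WHERE id >= ? AND id <= ?",
  lowerBound = 1,
  upperBound = 100000,
  numPartitions = 3,
  mapRow = rs => (rs.getLong("id"), rs.getString("url"))
)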
Now look at the compute() function:
override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
  context.addTaskCompletionListener{ context => closeIfNeeded() }
  val part = thePart.asInstanceOf[JdbcPartition]
  val conn = getConnection()
  val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

  if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
    stmt.setFetchSize(Integer.MIN_VALUE)
    logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
  }

  stmt.setLong(1, part.lower)
  stmt.setLong(2, part.upper)
  val rs = stmt.executeQuery()

  override def getNext: T = {
    if (rs.next()) {
      mapRow(rs)
    } else {
      finished = true
      null.asInstanceOf[T]
    }
  }

  override def close() {
    try {
      if (null != rs && ! rs.isClosed()) {
        rs.close()
      }
    } catch {
      case e: Exception => logWarning("Exception closing resultset", e)
    }
    try {
      if (null != stmt && ! stmt.isClosed()) {
        stmt.close()
      }
    } catch {
      case e: Exception => logWarning("Exception closing statement", e)
    }
    try {
      if (null != conn && ! conn.isClosed()) {
        conn.close()
      }
      logInfo("closed connection")
    } catch {
      case e: Exception => logWarning("Exception closing connection", e)
    }
  }
}