Spark自定义分区(Partitioner)

时间:2021-08-22 20:55:51

我们都知道Spark内部提供了HashPartitionerRangePartitioner两种分区策略(这两种分区的代码解析可以参见:《Spark分区器HashPartitioner和RangePartitioner代码详解》),这两种分区策略在很多情况下都适合我们的场景。但是有些情况下,Spark内部不能符合咱们的需求,这时候我们就可以自定义分区策略。为此,Spark提供了相应的接口,我们只需要扩展Partitioner抽象类,然后实现里面的三个方法:

01 package org.apache.spark
02  
03 /**
04  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
05  * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
06  */
07 abstract class Partitioner extends Serializable {
08   def numPartitions: Int
09   def getPartition(key: Any): Int
10 }

  def numPartitions:
Int
:这个方法需要返回你想要创建分区的个数;
  def getPartition(key:
Any): Int
:这个函数需要对输入的key做计算,然后返回该key的分区ID,范围一定是0到numPartitions-1
  equals():这个是Java标准的判断相等的函数,之所以要求用户实现这个函数是因为Spark内部会比较两个RDD的分区是否一样。

  假如我们想把来自同一个域名的URL放到一台节点上,比如:http://www.iteblog.comhttp://www.iteblog.com/archives/1368,如果你使用HashPartitioner,这两个URL的Hash值可能不一样,这就使得这两个URL被放到不同的节点上。所以这种情况下我们就需要自定义我们的分区策略,可以如下实现:

01 package com.iteblog.utils
02  
03 import org.apache.spark.Partitioner
04  
05 /**
06  * User: 过往记忆
07  * Date: 2015-05-21
08  * Time: 下午23:34
09  * bolg: http://www.iteblog.com
10  * 本文地址:http://www.iteblog.com/archives/1368
11  * 过往记忆博客,专注于hadoop、hive、spark、shark、flume的技术博客,大量的干货
12  * 过往记忆博客微信公共帐号:iteblog_hadoop
13  */
14  
15 class IteblogPartitioner(numParts: Int) extends Partitioner {
16   override def numPartitions: Int = numParts
17  
18   override def getPartition(key: Any): Int = {
19     val domain = new java.net.URL(key.toString).getHost()
20     val code = (domain.hashCode % numPartitions)
21     if (code < 0) {
22       code + numPartitions
23     else {
24       code
25     }
26   }
27  
28   override def equals(other: Any): Boolean = other match {
29     case iteblog: IteblogPartitioner =>
30       iteblog.numPartitions == numPartitions
31     case _ =>
32       false
33   }
34  
35   override def hashCode: Int = numPartitions
36 }

因为hashCode值可能为负数,所以我们需要对他进行处理。然后我们就可以在partitionBy()方法里面使用我们的分区:

1 iteblog.partitionBy(new IteblogPartitioner(20))

  类似的,在Java中定义自己的分区策略和Scala类似,只需要继承org.apache.spark.Partitioner,并实现其中的方法即可。

  在Python中,你不需要扩展Partitioner类,我们只需要对iteblog.partitionBy()加上一个额外的hash函数,如下:

查看源代码打印帮助
1 import urlparse
2  
3 def iteblog_domain(url):
4   return hash(urlparse.urlparse(url).netloc)
5  
6 iteblog.partitionBy(20, iteblog_domain)

 上述部分转载自(http://www.iteblog.com/) 下述部分转载自 http://blog.csdn.net/bluejoe2000/article/details/41415087

RDD是个抽象类,定义了诸如map()、reduce()等方法,但实际上继承RDD的派生类一般只要实现两个方法:

  • def getPartitions: Array[Partition]
  • def compute(thePart: Partition, context: TaskContext): NextIterator[T]

getPartitions()用来告知怎么将input分片;

compute()用来输出每个Partition的所有行(行是我给出的一种不准确的说法,应该是被函数处理的一个单元);

以一个hdfs文件HadoopRDD为例:

[java] view plain copy print?Spark自定义分区(Partitioner)Spark自定义分区(Partitioner)
  1. override def getPartitions: Array[Partition] = {  
  2.   val jobConf = getJobConf()  
  3.   // add the credentials here as this can be called before SparkContext initialized  
  4.   SparkHadoopUtil.get.addCredentials(jobConf)  
  5.   val inputFormat = getInputFormat(jobConf)  
  6.   if (inputFormat.isInstanceOf[Configurable]) {  
  7.     inputFormat.asInstanceOf[Configurable].setConf(jobConf)  
  8.   }  
  9.   val inputSplits = inputFormat.getSplits(jobConf, minPartitions)  
  10.   val array = new Array[Partition](inputSplits.size)  
  11.   for (i <- 0 until inputSplits.size) {  
  12.     array(i) = new HadoopPartition(id, i, inputSplits(i))  
  13.   }  
  14.   array  
  15. }  
它直接将各个split包装成RDD了,再看compute():

[java] view plain copy print?Spark自定义分区(Partitioner)Spark自定义分区(Partitioner)
  1. override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {  
  2.   val iter = new NextIterator[(K, V)] {  
  3.   
  4.     val split = theSplit.asInstanceOf[HadoopPartition]  
  5.     logInfo("Input split: " + split.inputSplit)  
  6.     var reader: RecordReader[K, V] = null  
  7.     val jobConf = getJobConf()  
  8.     val inputFormat = getInputFormat(jobConf)  
  9.     HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),  
  10.       context.stageId, theSplit.index, context.attemptId.toInt, jobConf)  
  11.     reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)  
  12.   
  13.     // Register an on-task-completion callback to close the input stream.  
  14.     context.addTaskCompletionListener{ context => closeIfNeeded() }  
  15.     val key: K = reader.createKey()  
  16.     val value: V = reader.createValue()  
  17.   
  18.     // Set the task input metrics.  
  19.     val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)  
  20.     try {  
  21.       /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't 
  22.        * always at record boundaries, so tasks may need to read into other splits to complete 
  23.        * a record. */  
  24.       inputMetrics.bytesRead = split.inputSplit.value.getLength()  
  25.     } catch {  
  26.       case e: java.io.IOException =>  
  27.         logWarning("Unable to get input size to set InputMetrics for task", e)  
  28.     }  
  29.     context.taskMetrics.inputMetrics = Some(inputMetrics)  
  30.   
  31.     override def getNext() = {  
  32.       try {  
  33.         finished = !reader.next(key, value)  
  34.       } catch {  
  35.         case eof: EOFException =>  
  36.           finished = true  
  37.       }  
  38.       (key, value)  
  39.     }  
  40.   
  41.     override def close() {  
  42.       try {  
  43.         reader.close()  
  44.       } catch {  
  45.         case e: Exception => logWarning("Exception in RecordReader.close()", e)  
  46.       }  
  47.     }  
  48.   }  
  49.   new InterruptibleIterator[(K, V)](context, iter)  
  50. }  
它调用reader返回一系列的K,V键值对。

再来看看数据库的JdbcRDD:

[java] view plain copy print?Spark自定义分区(Partitioner)Spark自定义分区(Partitioner)
  1. override def getPartitions: Array[Partition] = {  
  2.   // bounds are inclusive, hence the + 1 here and - 1 on end  
  3.   val length = 1 + upperBound - lowerBound  
  4.   (0 until numPartitions).map(i => {  
  5.     val start = lowerBound + ((i * length) / numPartitions).toLong  
  6.     val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1  
  7.     new JdbcPartition(i, start, end)  
  8.   }).toArray  
  9. }  
它直接将结果集分成numPartitions份。其中很多参数都来自于构造函数:

[java] view plain copy print?Spark自定义分区(Partitioner)Spark自定义分区(Partitioner)
  1. class JdbcRDD[T: ClassTag](  
  2.     sc: SparkContext,  
  3.     getConnection: () => Connection,  
  4.     sql: String,  
  5.     lowerBound: Long,  
  6.     upperBound: Long,  
  7.     numPartitions: Int,  
  8.     mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)  
再看看compute()函数:

[java] view plain copy print?Spark自定义分区(Partitioner)Spark自定义分区(Partitioner)
  1. override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {  
  2.   context.addTaskCompletionListener{ context => closeIfNeeded() }  
  3.   val part = thePart.asInstanceOf[JdbcPartition]  
  4.   val conn = getConnection()  
  5.   val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)  
  6.   
  7.   // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,  
  8.   // rather than pulling entire resultset into memory.  
  9.   // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html  
  10.   if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {  
  11.     stmt.setFetchSize(Integer.MIN_VALUE)  
  12.     logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")  
  13.   }  
  14.   
  15.   stmt.setLong(1, part.lower)  
  16.   stmt.setLong(2, part.upper)  
  17.   val rs = stmt.executeQuery()  
  18.   
  19.   override def getNext: T = {  
  20.     if (rs.next()) {  
  21.       mapRow(rs)  
  22.     } else {  
  23.       finished = true  
  24.       null.asInstanceOf[T]  
  25.     }  
  26.   }  
  27.   
  28.   override def close() {  
  29.     try {  
  30.       if (null != rs && ! rs.isClosed()) {  
  31.         rs.close()  
  32.       }  
  33.     } catch {  
  34.       case e: Exception => logWarning("Exception closing resultset", e)  
  35.     }  
  36.     try {  
  37.       if (null != stmt && ! stmt.isClosed()) {  
  38.         stmt.close()  
  39.       }  
  40.     } catch {  
  41.       case e: Exception => logWarning("Exception closing statement", e)  
  42.     }  
  43.     try {  
  44.       if (null != conn && ! conn.isClosed()) {  
  45.         conn.close()  
  46.       }  
  47.       logInfo("closed connection")  
  48.     } catch {  
  49.       case e: Exception => logWarning("Exception closing connection", e)  
  50.     }  
  51.   }  
  52. }