概述
RDD中最重要的一项是partition,RDD的五个要素里面有两个牵扯到partition,如下所示,研究partition不仅要研究Partition的定义,还要研究不同RDD的如下两个方法。
// 对特定分配的分区进行操作
def compute(split: Partition, context: TaskContext): Iterator[T]
// 获取所有的partition
protected def getPartitions: Array[Partition]
Partition
Partition是一个trait,定义比较简单,一个变量表明是第几个partition,一个hashCode和equal方法定义,每种类型的RDD都会有相应的Partition,继承于trait Partition,下文着重分析常见的Parition。
trait Partition extends Serializable {
def index: Int
override def hashCode(): Int = index
override def equals(other: Any): Boolean = super.equals(other)
}
JdbcPartition
JdbcParition定义
JdbcPartition定义如下:继承于Partition并且多了两个变量。
// 相对于partition trait,多定义了两个变量,lower和upper,供sql获取数据范围
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long)
extends Partition {
override def index: Int = idx
}
private[spark] 表示这个类只能在包名中含有spark的类中访问
JdbcRDD定义
class JdbcRDD[T: ClassTag](
sc: SparkContext,
getConnection: () => Connection,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
extends RDD[T](sc, Nil) with Logging
如上是JdbcRDD的定义,继承于RDD,扩展了几个字段:
-
Connection定义连接函数,返回值是Connection
-
sql是读取数据的逻辑sql
-
lowerBound是sql逻辑中读取数据的lower界限,开区间
-
upperBound是是sql逻辑中读取数据的lower界限,开区间
-
numPartitions定义了多少个partition
-
mapRow是一个方法,以ResultSet为输入参数,输出是想要得到的类型的单独一行,只能调用getInt、getString等方法,RDD负责调用next,默认将ResultSet映射到一个对象数组。
def resultSetToObjectArray(rs: ResultSet): Array[Object] = { // 获取resutSet的meta信息,对于每一列数据,获取object并放入到array数组中。 Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) } 可以自定义mapRow方法: 输出一个字符串,可以返回各种类型 def maprow(rs: ResultSet): String = { @volatile var res: String = "" @volatile var i = 1 val count = rs.getMetaData.getColumnCount while (i < count) { if (i != 1) res = res +", " + rs.getString(i) else res = rs.getString(i) i = i + 1 } res }
getPartitions
代码如下:
override def getPartitions: Array[Partition] = {
val length = BigInt(1) + upperBound - lowerBound
(0 until numPartitions).map { i =>
val start = lowerBound + ((i * length) / numPartitions)
val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
new JdbcPartition(i, start.toLong, end.toLong)
}.toArray
}
可以看到,JdbcRDD分区的方法是对分区数据进行平均切分数据区间,(lowerBound)到(upperBound)分为(numPartitions)份,封装为JdbcPartition并返回,完成数据切分。
compute
override def compute(thePart: Partition, context: TaskContext): Iterator[T] =
new NextIterator[T] {
context.addTaskCompletionListener{ context => closeIfNeeded() }
// 获取相应分区
val part = thePart.asInstanceOf[JdbcPartition]
// 拿取定义的连接函数
val conn = getConnection()
val stmt = conn.prepareStatement(sql,
ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY)
val url = conn.getMetaData.getURL
if (url.startsWith("jdbc:mysql:")) {
stmt.setFetchSize(Integer.MIN_VALUE)
} else {
stmt.setFetchSize(100)
}
logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery() // 执行sql
}
JdbcRDD的compute方法首先将Partition强转为JdbcPartition,获取连接并预处理sql,将sql中的参数分别用Partition的lower和upper替换并执行查询,可以看到此处并没有得到结果的处理,原因是这个是惰性的执行。
实例分析
经过上面介绍,我们来分析一个使用JdbcRDD的例子:
new JdbcRDD(
sc,
() => { DriverManager.getConnection(URL, USERNAME, PASSWORD) },
"select title, author from books where ? <= id and id <= ?",
1,
100,
3,
(r: ResultSet) => { r.getString(1) + "," + r.getString(2) }
).count()
如上,定义了一个JdbcRDD,给定的sql语句是SELECT gender FROM people WHERE ? <= ID AND ID <= ?
,选取的都区范围是[1, 100],分为三个分区,并且指定了mapRow方法:简单的获取两列数值并且放到一个字符串中返回。
- getPartition方法:分为三个分区,(1, 100)分为三份:(1,33)、(34,66)、(67,100)
- compute方法:对各个分区进行获取sql连接,填充id的开始即结束值,然后执行查询获取数据操作。
参考
- https://blog.csdn.net/u011564172/article/details/53611109