When we use JdbcRDD, the parameters of its primary constructor are as follows:
```scala
sc: SparkContext,
getConnection: () => Connection,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _
```
According to the comments in its source, the query must contain two `?` placeholders, for example:

```sql
select title, author from books where ? <= id and id <= ?
```

and the bound parameters are documented as follows:

```scala
 * @param lowerBound the minimum value of the first placeholder
 * @param upperBound the maximum value of the second placeholder
 *   The lower and upper bounds are inclusive.
 * @param numPartitions the number of partitions.
 *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
 *   the query would be executed twice, once with (1, 10) and once with (11, 20)
```
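For reference, a typical call to the stock JdbcRDD looks like the minimal sketch below, assuming an existing SparkContext `sc`; the URL, credentials, and `books` table are placeholders to adapt to your environment:

```scala
import java.sql.DriverManager
import org.apache.spark.rdd.JdbcRDD

// Placeholder connection details; the two "?" placeholders in the SQL
// and the three bound arguments are all mandatory.
val rdd = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:mysql://localhost:3306/test", "user", "pass"),
  "select title, author from books where ? <= id and id <= ?",
  1, 20, 2,  // lowerBound, upperBound, numPartitions
  r => (r.getString(1), r.getString(2)))
```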
From the above we can see that these parameters are required by JdbcRDD's primary constructor, and no auxiliary constructor is provided. So whenever we issue a query we are forced to supply the lower and upper bounds, that is, the SQL must carry the placeholder conditions, whose values are then passed into the primary constructor. In real use, or simply in testing, we sometimes want to run a query with no such parameters, and here JdbcRDD leaves us helpless. So what can we do?
There are probably many approaches. Two simple ones are to re-implement JdbcRDD ourselves, or to subclass JdbcRDD and define a constructor that hides the bounds. This article only walks through the first, re-implementing JdbcRDD, which keeps the program loosely coupled.
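For completeness, the subclassing route might look roughly like the sketch below; it is not pursued further in this article. The trick of wrapping the caller's SQL so it satisfies the two mandatory placeholders is my own illustration, not something from the Spark source, and `UnboundedJdbcRDD` is a hypothetical name:

```scala
import java.sql.{Connection, ResultSet}
import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

// Hypothetical sketch: satisfy JdbcRDD's mandatory "?" placeholders by
// wrapping the SQL in a condition that is always true for bounds (1, 1),
// and run everything in a single partition.
class UnboundedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    mapRow: ResultSet => T)
  extends JdbcRDD[T](
    sc,
    getConnection,
    s"SELECT * FROM ($sql) t WHERE ? <= 1 AND 1 <= ?",  // (1, 1) => always true
    1, 1, 1,  // lowerBound, upperBound, numPartitions
    mapRow)
```

The rewrite below avoids this kind of string surgery entirely.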
Reading JdbcRDD's source code, we find that, in fact:

- lowerBound defines the lower bound of the query
- upperBound defines the upper bound of the query
- numPartitions defines the number of partitions for the query
In a real production environment these three parameters can be very useful: together they define the range of data each partition queries, which is no doubt why the Spark developers made them part of the design.
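To make those ranges concrete, the splitting logic (paraphrased from JdbcRDD's getPartitions in the Spark source) works out like this:

```scala
// How (lowerBound, upperBound, numPartitions) becomes one inclusive
// (start, end) range per partition; bounds are inclusive, hence the
// "+ 1" on length and "- 1" on end.
val lowerBound = 1L
val upperBound = 20L
val numPartitions = 2
(0 until numPartitions).foreach { i =>
  val length = BigInt(1) + upperBound - lowerBound
  val start = (lowerBound + (BigInt(i) * length) / numPartitions).toLong
  val end = (lowerBound + (BigInt(i + 1) * length) / numPartitions).toLong - 1
  println(s"partition $i queries ($start, $end)")  // (1, 10) then (11, 20)
}
```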
Note: this example simply strips out these three parameters; be aware that this is not the only way to do it. Since the partitioning parameter is gone, the code below defaults to a single partition. You can set more partitions by hand in the code, but keep in mind that with no bounds in the SQL every partition would execute the identical query and duplicate the data.
Besides modifying JdbcRDD's source we also need NextIterator.scala. That file is moved over unchanged (the original is private to the spark package, so it cannot be referenced from other packages), and therefore is not reproduced in full here.
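For readers without the Spark source at hand, the contract that file provides looks roughly like this (paraphrased from org.apache.spark.util.NextIterator with the `private[spark]` modifier dropped; copy the real file rather than relying on this sketch):

```scala
// Rough paraphrase of org.apache.spark.util.NextIterator: an Iterator
// whose subclasses produce elements via getNext() and signal exhaustion
// by setting finished = true, with a close() hook for cleanup.
abstract class NextIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = _
  private var closed = false
  protected var finished = false

  /** Produce the next element, or set finished = true when exhausted. */
  protected def getNext(): U

  /** Release any resources; called at most once. */
  protected def close(): Unit

  def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true
      close()
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}
```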
The modified JdbcRDD.scala is renamed JDBCRDD.scala, and NextIterator.scala sits in the same package as JDBCRDD.scala. Here is the source of JDBCRDD.scala:

```scala
import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD

// Note: org.apache.spark.internal.Logging is private[spark], so this file
// must live in a package under org.apache.spark for the import to compile
// (or substitute your own logging).

class JDBCPartition(idx: Int) extends Partition {
  override def index: Int = idx
}

class JDBCRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    mapRow: (ResultSet) => T = JDBCRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  // One partition by default; widen the range for more, but remember that
  // every partition runs the identical SQL.
  override def getPartitions: Array[Partition] = {
    (0 until 1).map { i =>
      new JDBCPartition(i)
    }.toArray
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] {
    context.addTaskCompletionListener { _ => closeIfNeeded() }

    val part = thePart.asInstanceOf[JDBCPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
      // Tell the MySQL driver to stream rows one at a time rather than
      // buffering the whole result set in memory.
      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      stmt.setFetchSize(100)
    }
    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")

    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}

object JDBCRDD {
  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]

  // Java-friendly factory methods, mirroring JdbcRDD.create.
  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new JDBCRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, mapRow)
  }
}
```
The following example exercises the JDBCRDD.scala above:
```scala
import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

object TestJDBC {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TestJDBC").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      // Only this function is shipped to the executors; the Connection
      // itself is created lazily where the partition runs.
      val connection = () => {
        Class.forName("com.mysql.jdbc.Driver").newInstance()
        DriverManager.getConnection("jdbc:mysql://192.168.0.4:3306/spark", "root", "root")
      }
      val jdbcRDD = new JDBCRDD(
        sc,
        connection,
        "SELECT * FROM result",
        r => {
          val id = r.getInt(1)
          val code = r.getString(2)
          (id, code)
        }
      )
      println(jdbcRDD.collect().toBuffer)
    } catch {
      case e: Exception => e.printStackTrace()
    } finally {
      sc.stop()
    }
  }
}
```
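To run this, spark-core and the MySQL JDBC driver must be on the classpath; in sbt that might look like the snippet below (the version numbers are illustrative, match them to your own environment):

```scala
// build.sbt -- versions are illustrative, not prescriptive
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.1.0",
  "mysql" % "mysql-connector-java" % "5.1.38"
)
```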
That completes this simple modification of the JdbcRDD source. I hope you find it useful.