How Spark SQL 2.0's External Data Source Extension Works (Reading from an External System)

This approach is quite flexible and can connect to almost any external system; both spark-jdbc and the Cassandra connector are implemented this way. You need to implement the following classes:

  • DefaultSource

This is the entry class that sets up the connection to the external data source. Spark SQL looks for a class with exactly this name by default, so do not rename it.

```
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
import org.slf4j.LoggerFactory

private[tsfile] class DefaultSource extends RelationProvider with DataSourceRegister {
  private final val logger = LoggerFactory.getLogger(classOf[DefaultSource])

  override def shortName(): String = "tsfile"

  override def createRelation(
                               sqlContext: SQLContext,
                               parameters: Map[String, String]): BaseRelation = {

    val tsfileOptions = new TSFileOptions(parameters)

    if (tsfileOptions.url == null || tsfileOptions.sql == null) {
      sys.error("TSFile node or sql not specified")
    }
    new TSFileRelation(tsfileOptions)(sqlContext.sparkSession)

  }
}
```
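
TSFileOptions is not shown in this post (it lives in the full project linked at the end). A minimal sketch of what such an options wrapper could look like, assuming it simply pulls the keys used above (url, sql) and further below (lowerBound, upperBound, numPartition) out of the parameters map; the defaults here are illustrative only:

```
// Hypothetical sketch only: a thin wrapper around the options map Spark SQL passes in.
// The real class in the repository may expose different keys and defaults.
class TSFileOptions(parameters: Map[String, String]) extends Serializable {
  val url: String = parameters.getOrElse("url", null)
  val sql: String = parameters.getOrElse("sql", null)
  val lowerBound: String = parameters.getOrElse("lowerBound", "0")
  val upperBound: String = parameters.getOrElse("upperBound", "0")
  val numPartition: String = parameters.getOrElse("numPartition", "1")
}
```
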
  • Custom Relation

This is the core class: the interfaces through which the table schema and the data are returned to Spark SQL all live here.

```
import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.slf4j.LoggerFactory

// Information used to compute the partitions (time range and partition count)
private case class TSFilePartitioningInfo(
                                           start: Long,
                                           end: Long,
                                           numPartitions: Int)


private object TSFileRelation {

  private final val logger = LoggerFactory.getLogger(classOf[TSFileRelation])

  // Compute the partitions. This simplified example returns a single partition that
  // covers everything; see the sketch after the RDD section for how the time range
  // could actually be split.
  def getPartitions(partitionInfo: TSFilePartitioningInfo): Array[Partition] = {
    Array[Partition](TSFilePartition(null, 0, 0L, 0L))
  }
}
    
class TSFileRelation protected[tsfile](val options: TSFileOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation with PrunedFilteredScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  private final val logger = LoggerFactory.getLogger(classOf[TSFileRelation])

  // Build the table schema from the connector's own options
  override def schema: StructType = {
    Converter.toSparkSchema(options)
  }

  // Build our own RDD from the required columns and filters pushed down by Spark SQL
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val start: Long = options.lowerBound.toLong
    val end: Long = options.upperBound.toLong
    val numPartition = options.numPartition.toInt

    val partitionInfo = TSFilePartitioningInfo(start, end, numPartition)

    val parts = TSFileRelation.getPartitions(partitionInfo)

    new TSFileRDD(sparkSession.sparkContext,
      options,
      schema,
      requiredColumns,
      filters,
      parts).asInstanceOf[RDD[Row]]
  }
}
```
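
Converter.toSparkSchema is also part of the full project rather than this post. A minimal sketch of what it could look like; the column names and types here are purely illustrative, and a real implementation would derive the fields from the external system's metadata instead of hard-coding them:

```
import org.apache.spark.sql.types._

// Hypothetical sketch only: build a Spark schema describing the external data.
// A real implementation would query the external system's metadata (for example
// through options.url) rather than hard-coding the fields.
object Converter {
  def toSparkSchema(options: TSFileOptions): StructType = {
    StructType(Seq(
      StructField("time", LongType, nullable = false),  // timestamp of each data point
      StructField("value", FloatType, nullable = true)  // the measured value
    ))
  }
}
```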

  • Custom RDD

Extend RDD and override getPartitions, which defines how the data is partitioned, and compute, which returns the data of each partition.

```
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Information about each partition; extends Spark's Partition.
case class TSFilePartition(where: String, id: Int, start: java.lang.Long, end: java.lang.Long) extends Partition {
  override def index: Int = id
}



class TSFileRDD private[tsfile](
                                sc: SparkContext,
                                options: TSFileOptions,
                                schema : StructType,
                                requiredColumns: Array[String],
                                filters: Array[Filter],
                                partitions: Array[Partition])
  extends RDD[Row](sc, Nil) {

  // Given a partition, return an iterator over its data.
  override def compute(split: Partition, context: TaskContext): Iterator[Row] = new Iterator[Row] {
    var finished = false
    var gotNext = false
    var nextValue: Row = null
    val inputMetrics = context.taskMetrics().inputMetrics

    val part = split.asInstanceOf[TSFilePartition]
 
    var taskInfo: String = _
    Option(TaskContext.get()).foreach { taskContext =>
      // `conn` refers to the connection to the external system; the full
      // implementation (see the repository linked at the end) opens it for this
      // partition and closes it when the task completes.
      taskContext.addTaskCompletionListener { _ => conn.close() }
      taskInfo = "task Id: " + taskContext.taskAttemptId() + " partition Id: " + taskContext.partitionId()
      println(taskInfo)
    }

    // Buffer holding the values of the current row.
    private val rowBuffer = Array.fill[Any](schema.length)(null)

	
    override def hasNext: Boolean = {
      // Simplified for this post: a real implementation checks whether the query
      // result from the external system still has rows for this partition.
      false
    }

    override def next(): Row = {
      // Simplified for this post: a real implementation fills rowBuffer with one
      // value per required column read from the external system.
      rowBuffer.indices.foreach(index => rowBuffer(index) = null)
      Row.fromSeq(rowBuffer)
    }
  }

  // The partitions computed by TSFileRelation are handed in through the constructor.
  override def getPartitions: Array[Partition] = partitions

}
```
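
The getPartitions shown in TSFileRelation returns a single dummy partition. In practice you would usually split the [lowerBound, upperBound] time range into numPartition slices so that each Spark task reads a disjoint piece of the data. A minimal sketch; the object name TimeRangePartitioner is made up for illustration:

```
import org.apache.spark.Partition

// Hypothetical sketch: split the [start, end] time range evenly into numPartitions
// TSFilePartition instances, so each Spark task reads a disjoint slice of data.
object TimeRangePartitioner {
  def getPartitions(info: TSFilePartitioningInfo): Array[Partition] = {
    val length = (info.end - info.start + 1) / info.numPartitions
    (0 until info.numPartitions).map { i =>
      val start = info.start + i * length
      // The last partition absorbs any remainder of the range.
      val end = if (i == info.numPartitions - 1) info.end else start + length - 1
      TSFilePartition(null, i, start, end): Partition
    }.toArray
  }
}
```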

Usage example:

```
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object test {
  def main(args: Array[String]): Unit = {
    val sqlContext = new SQLContext(new SparkContext("local[2]", "readIoTDB"))
    // "your.package.name" is a placeholder for the package that contains DefaultSource.
    val createSql = "create temporary view iotdb using your.package.name options(url = \"jdbc:tsfile://127.0.0.1:6667/\", union = \"true\")"
    sqlContext.sql(createSql)
    val df = sqlContext.sql("select * from iotdb")

    df.show()

  }

}
```
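
Because DefaultSource mixes in DataSourceRegister with the short name "tsfile", the connector can also be loaded through the DataFrameReader API, provided the class is registered for Java's ServiceLoader (a text file named org.apache.spark.sql.sources.DataSourceRegister under src/main/resources/META-INF/services/ containing the fully qualified name of DefaultSource). A sketch, assuming the same option keys as the SQL example above:

```
import org.apache.spark.sql.SparkSession

object TestFormat {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("readIoTDB").getOrCreate()

    // "tsfile" is the short name returned by DefaultSource.shortName().
    val df = spark.read
      .format("tsfile")
      .option("url", "jdbc:tsfile://127.0.0.1:6667/")
      .option("union", "true")
      .load()

    df.show()
    spark.stop()
  }
}
```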

Full source code:

https://github.com/qiaojialin/spark-iotdb-connector

If you're interested, you can follow my WeChat official account: 数据库漫游指南.
