SparkSQL Custom External Data Sources: Source Code Analysis and a Worked Example

Let's start from the entry points that the built-in JDBC data source uses:

Source Code Analysis

// A class extending BaseRelation must be able to produce its schema as a StructType.
// Concrete implementations should also mix in one of the Scan traits below.
abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType

  // Estimated size of the relation, used by the optimizer (e.g. for broadcast decisions).
  def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
  // Whether rows need to be converted to Catalyst's internal representation.
  def needConversion: Boolean = true
  // Filters this relation cannot handle itself; by default, none are handled.
  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}
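
To see how those defaults come into play, here is a minimal hedged sketch of a relation that overrides them; the class name and size estimate are illustrative, not part of the article's example:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Hypothetical relation overriding BaseRelation's defaults.
class SizedRelation(override val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(StructField("id", LongType, false) :: Nil)

  // Report an estimated size so the optimizer can weigh strategies such as broadcast joins.
  override def sizeInBytes: Long = 1024L * 1024L // assumption: roughly 1 MB of data

  // Returning an empty array claims every pushed-down filter is fully handled by the
  // source (only meaningful when combined with PrunedFilteredScan).
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = Array.empty

  override def buildScan(): RDD[Row] = sqlContext.sparkContext.emptyRDD[Row]
}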


// Full table scan, equivalent to select * from xxx
trait TableScan {
  def buildScan(): RDD[Row]
}

// Column pruning: unneeded columns are dropped before rows are produced
trait PrunedScan {
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

// Column pruning plus row filtering, similar to select col1, col2 ... where ...
trait PrunedFilteredScan {
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}
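
To make the pruning and filtering contract concrete, here is a hedged sketch of a PrunedFilteredScan over an in-memory collection. The class and data layout are made up for illustration; only EqualTo is pushed down, and Spark re-evaluates any filter the source leaves unhandled:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, EqualTo, Filter, PrunedFilteredScan}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// Hypothetical in-memory relation: prunes columns and pushes down EqualTo filters.
class InMemoryFilteredRelation(override val sqlContext: SQLContext,
                               rows: Seq[Map[String, Any]])
  extends BaseRelation with PrunedFilteredScan {

  override def schema: StructType = StructType(
    StructField("id", LongType, false) ::
      StructField("name", StringType, false) :: Nil)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val kept = rows.filter { row =>
      filters.forall {
        case EqualTo(attribute, value) => row.get(attribute).contains(value)
        case _ => true // filters left unhandled here are re-applied by Spark after the scan
      }
    }
    // Keep only the requested columns, in the requested order.
    val projected = kept.map(row => Row.fromSeq(requiredColumns.map(row(_)).toSeq))
    sqlContext.sparkContext.parallelize(projected)
  }
}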

// Writable: supports inserting a DataFrame into the relation
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}
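
And a minimal sketch of InsertableRelation, assuming the target is a directory that Spark's built-in CSV writer can serve; the class name and write strategy are illustrative:

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types.StructType

// Hypothetical writable relation delegating to Spark's CSV writer.
class WritableTextRelation(override val sqlContext: SQLContext,
                           path: String,
                           override val schema: StructType)
  extends BaseRelation with InsertableRelation {

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
    // Reuse the CSV writer so the inserted rows keep a comma-separated layout.
    data.write.mode(mode).csv(path)
  }
}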

Example Implementation

The data to read (a plain text file with no field types or schema):

101,zhansan,0,10000,200000
102,lisi,0,150000,250000
103,wangwu,1,3000,5
104,zhaoliu,2,500,6
102,lisi,0,150000,250000

Code

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Note: the class must be named DefaultSource. Spark SQL appends "DefaultSource" to the
// package passed to format() in order to shorten verbose calls; for example,
// "org.apache.spark.sql.json" resolves to the data source
// "org.apache.spark.sql.json.DefaultSource". To use another class name, you must
// register a data source alias instead.
class DefaultSource extends RelationProvider with SchemaRelationProvider {

  override def createRelation(
                      sqlContext: SQLContext,
                      parameters: Map[String, String],
                      schema: StructType): BaseRelation = {

    val path = parameters.get("path")

    path match {
      case Some(x) => new TextDataSourceRelation(sqlContext, x, schema)
      case _ => throw new IllegalArgumentException("path is required...")
    }
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

}
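
If you do want a different class name or a friendlier alias, Spark's DataSourceRegister trait provides one. A sketch, assuming the class is also listed in a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file on the classpath (the alias "text-custom" is made up):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

// Hypothetical alias registration: spark.read.format("text-custom") would then resolve here.
class AliasedSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "text-custom"

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    new TextDataSourceRelation(sqlContext, parameters("path"), null)
  }
}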

Next comes the relation itself:

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._

// Only the full table scan (TableScan) is implemented here.
class TextDataSourceRelation(override val sqlContext: SQLContext,
                             path: String,
                             userSchema: StructType)
  extends BaseRelation with TableScan with Logging {

  override def schema: StructType = {
    if (null != userSchema) {
      userSchema
    } else {
      // Fall back to a built-in schema when the caller does not supply one.
      StructType(
        StructField("id", LongType, false) ::
          StructField("name", StringType, false) ::
          StructField("gender", StringType, false) ::
          StructField("salary", LongType, false) ::
          StructField("comm", LongType, false) :: Nil
      )
    }
  }

  override def buildScan(): RDD[Row] = {
    logInfo("this is custom buildScan")
    // wholeTextFiles reads each file as a single (path, content) pair, which suits
    // small files; keep only the content.
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)

    val fieldsSchema = schema.fields

    val rows = rdd.map(fileContent => {
      // Split the file content into lines.
      val lines = fileContent.split("\n")
      // Split each line on commas and trim the values.
      val data = lines.map(_.split(",").map(x => x.trim)).toSeq
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) =>
          val columnName = fieldsSchema(index).name
          // Map the gender code to a readable label; leave other columns as-is.
          val castValue = if (columnName.equalsIgnoreCase("gender")) {
            if (value.equalsIgnoreCase("0")) {
              "man"
            } else if (value.equalsIgnoreCase("1")) {
              "woman"
            } else {
              "unknown"
            }
          } else {
            value
          }
          // Cast the string to the type declared in the schema.
          SqlUtil.castTo(castValue, fieldsSchema(index).dataType)
      })

      result.map(x => Row.fromSeq(x))
    })

    // Each file produced a Seq[Row]; flatten them into a single RDD[Row].
    rows.flatMap(x => x)
  }
}

import org.apache.spark.sql.types.{DataType, LongType, StringType}

object SqlUtil {
  def castTo(value: String, dataType: DataType): Any = {
    dataType match {
      case LongType => value.toLong
      case StringType => value
      case _ => throw new IllegalArgumentException(s"Unsupported data type: $dataType")
    }
  }
}
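
A quick illustration of the helper; the demo object is made up:

object SqlUtilDemo extends App {
  import org.apache.spark.sql.types.{LongType, StringType}

  // Both calls return Any, matching the types declared in the schema.
  assert(SqlUtil.castTo("101", LongType) == 101L)
  assert(SqlUtil.castTo("lisi", StringType) == "lisi")
}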

Main method

import org.apache.spark.sql.SparkSession

object TestCustomSource {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("TextApp").master("local[2]").getOrCreate()

    // format() takes the package containing DefaultSource.
    val df = spark.read.format("com.kzw.bigdata.spark.sql04").option("path", "input/custom.txt").load()
    //df.show()
    df.printSchema()

    df.createOrReplaceTempView("customTable")

    val sql = "select * from customTable"
    spark.sql(sql).show()

    val sql2 = "select id,sum(salary) from customTable group by id"
    spark.sql(sql2).show()

    spark.stop()
  }
}
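
Because DefaultSource also implements SchemaRelationProvider, the caller may supply a schema instead of relying on the built-in one. A hedged sketch; the object name is made up, and the column order must still match the file:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object TestUserSchema {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("TextApp").master("local[2]").getOrCreate()

    // A user-supplied schema takes the SchemaRelationProvider path in DefaultSource.
    val userSchema = StructType(
      StructField("id", LongType, false) ::
        StructField("name", StringType, false) ::
        StructField("gender", StringType, false) ::
        StructField("salary", LongType, false) ::
        StructField("comm", LongType, false) :: Nil)

    val df = spark.read.format("com.kzw.bigdata.spark.sql04")
      .schema(userSchema)
      .option("path", "input/custom.txt")
      .load()
    df.printSchema()

    spark.stop()
  }
}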

Output:

root
 |-- id: long (nullable = false)
 |-- name: string (nullable = false)
 |-- gender: string (nullable = false)
 |-- salary: long (nullable = false)
 |-- comm: long (nullable = false)

+---+-------+-------+------+------+
| id|   name| gender|salary|  comm|
+---+-------+-------+------+------+
|101|zhansan|    man| 10000|200000|
|102|   lisi|    man|150000|250000|
|103| wangwu|  woman|  3000|     5|
|104|zhaoliu|unknown|   500|     6|
|102|   lisi|    man|150000|250000|
+---+-------+-------+------+------+

+---+-----------+
| id|sum(salary)|
+---+-----------+
|103|       3000|
|104|        500|
|101|      10000|
|102|     300000|
+---+-----------+