spark 自定义数据源

4 篇文章 0 订阅
2 篇文章 0 订阅
1、创建hbase数据源表
node1> bin/hbase shell

create 'spark_hbase_sql','cf'
put 'spark_hbase_sql','0001','cf:name','zhangsan'
put 'spark_hbase_sql','0001','cf:score','80'
put 'spark_hbase_sql','0002','cf:name','lisi'
put 'spark_hbase_sql','0002','cf:score','60'

2、创建Hbase的数据保存表
bin/hbase shell
create 'spark_hbase_write','cf'

3.代码 编写
package com.alibaba.programApp

import java.util
import java.util.Optional

import com.travel.utils.HbaseTools
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.{Connection, Put, Result, ResultScanner, Scan, Table}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType

/**
  * @author cherish
  * @create 2020-04-26 10:49
  */
object HBaseSourceAndSink {
    def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        conf.setMaster("local[2]").setAppName("sparkSqlSourceAndSink")
        val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

        //format 需要我们自定义数据源
        val df: DataFrame = spark.read.format("com.travel.programApp.HBaseSource")
            .option("hbase.table.name", "spark_hbase_sql") //我们自己带的一些参数
            .option("cf.cc", "cf:name,cf:score") //定义我们查询hbase的那些列
            .option("schema", "`name` STRING , `score` STRING") //定义我们表的schema 返回的数据是按照循序定义的
            .load

        df.createOrReplaceTempView("sparkHBaseSQL")

        df.printSchema()

        //分析得到的结果数据 , 将结果数据保存到hbase , redis 或者 mysql 中 或者es
        val resultDF: DataFrame = spark.sql("select * from sparkHBaseSQL where score > 70 ")

        resultDF.write.format("com.travel.programApp.HBaseSource")
            .mode(SaveMode.Overwrite)
            .option("hbase.table.name" , "spark_hbase_write")
            .option("cf" , "cf")
            .save()


    }
}

//自定义数据源,实现数据的查询
class  HBaseSource extends DataSourceV2 with ReadSupport with WriteSupport{
    /**
      * 定义我们映射的表的schema
      * @param options
      * @return
      */
    override def createReader(options: DataSourceOptions): DataSourceReader = {
        //从spark.read.format().option() 里面传过来的
        val tableName  :String = options.get("hbase.table.name").get()
        val cfAndCC :String = options.get("cf.cc").get()
        val schema:String =  options.get("schema").get()


        new HBaseDataSourceReader(tableName , cfAndCC , schema)
    }

    override def createWriter(jobId: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = {

        //从spark.read.format().option() 里面传过来的
        val tableName  :String = options.get("hbase.table.name").get()
        val family :String = options.get("cf").get()


        Optional.of(new HBaseDataSourceWriter(tableName))
    }
}

class HBaseDataSourceWriter(tableName : String) extends DataSourceWriter {

    /**
      * 将数据保存起来全部依赖这个方法
      * @return
      */
    override def createWriterFactory(): DataWriterFactory[Row] = {
        new HBaseDataWriterFactory(tableName)
    }

    //数据提方法
    override def commit(messages: Array[WriterCommitMessage]): Unit = {

    }

    //放弃数据的插入方法
    override def abort(messages: Array[WriterCommitMessage]): Unit = {

    }
}

class HBaseDataWriterFactory(tableName : String) extends DataWriterFactory[Row] {
    override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
        new HBaseDataWriter(tableName)
    }
}

class HBaseDataWriter(tableName :String) extends DataWriter[Row] {

    private val conn = HbaseTools.getHbaseConn
    private val table: Table = conn.getTable(TableName.valueOf(tableName))

    //写入数据
    override def write(record: Row): Unit = {
        val name:String = record.getString(0)
        val score  :  String = record.getString(1)

        val put = new Put("0001".getBytes())

        put.addColumn("cf".getBytes() , "name".getBytes() ,  name.getBytes())
        put.addColumn("cf".getBytes() , "score".getBytes() , score.getBytes())

        table.put(put)

    }

    //数据的提价方法,数据插入完成之后,在这个方法里面进行数据的事务提交
    override def commit(): WriterCommitMessage = {
        //hbase 里面没有事务 , 所以在这里就把 table 和 conn 关闭 , 然后返回null
        table.close()
        conn.close()
        null
    }

    override def abort(): Unit = {

    }
}

class HBaseDataSourceReader(tableName:String , cfAndCC:String , schema:String) extends DataSourceReader {
    override def readSchema(): StructType = {
        StructType.fromDDL(schema)
    }

    override def createDataReaderFactories(): util.List[DataReaderFactory[Row]] = {
        /**
          * new HBaseDateReaderFactory().asInstanceOf[DataReaderFactory[Row] ]:
          * 将 HBaseDateReaderFactory() 转成 DataReaderFactory[Row]  对象
          * Seq[] scala集合
          * as.Java  将 scala seq 集合转成 java 集合
          */
        import scala.collection.JavaConverters._
        Seq(new HBaseDateReaderFactory(tableName , cfAndCC).asInstanceOf[DataReaderFactory[Row]]).asJava
    }
}
class HBaseDateReaderFactory(tableName:String,cfAndCC:String) extends DataReaderFactory[Row] {
    override def createDataReader(): DataReader[Row] = {
        new HBaseDataReader(tableName , cfAndCC);
    }
}

/**
  * 自定义HBaseDateReader 实现dateReader接口
  *
  */
class HBaseDataReader(tableName :String , cfAndCC:String) extends DataReader[Row] {

    var conn: Connection = null
    var table: Table = null

    var scan = new Scan()
    var resultScanner: ResultScanner = null

    /**
      * 使用ProtobufUtil将sparkContext对象序列化成为一个字符串传下来,下面在反序列化
      */

    /**
      * 获取我们hbase的数据就在这
      * @return
      */
    def getIterator: Iterator[Seq[AnyRef]] = {
        conn = HbaseTools.getHbaseConn
        table = conn.getTable(TableName.valueOf(tableName))
        resultScanner = table.getScanner(scan)

        import scala.collection.JavaConverters._

        val iterator: Iterator[Seq[AnyRef]] = resultScanner.iterator().asScala.map(eachResult => {

            /*val cfCCArr: Array[String] = cfAndCC.split(",")
            val family: String = cfCCArr(0).split(":")(0)

            val clumn1: String = cfCCArr(0).split(":")(1)
            val clumn2: String = cfCCArr(2).split(":")(1)*/

            val name: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "name".getBytes()))
            val score: String = Bytes.toString(eachResult.getValue("cf".getBytes(), "score".getBytes()))
            System.out.println("===================================")
            System.out.println(Seq(name, score).toString())
            System.out.println("===================================")

            Seq(name, score)
        })
        iterator

    }

    val data:Iterator[Seq[AnyRef]] = getIterator

    /**
      * 这个方法反复不断的被调用,只要我们查询到了数据,就可以使用next方法一直获取下一条数据
      * @return
      */
    override def next(): Boolean = {
        data.hasNext
    }



    /**
      * 获取到的数据在这个方法里面一条条的解析,解析之后,映射到我们提前定义的表里面去
      * @return
      */
    override def get(): Row ={
        Row.fromSeq(data.next())
    }

    /**
      *
      */
    override def close(): Unit = {
        table.close()
        conn.close()
    }
}
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值