Spark自定义RDD访问HBase

http://www.zhyea.com/2017/06/21/visit-hbase-with-custom-spark-rdd.html

这里介绍一个在Spark上使用自定义RDD获取HBase数据的方案。

这个方案的基础是我们的HBase表的行键设计。行键设计大概是这样子的:标签ID+时间戳+随机码。平时的需求主要是导出指定标签在某个时间范围内的全部记录。根据需求和行键设计确定下实现的大方向:使用行键中的时间戳进行partition并界定startRow和stopRow来缩小查询范围,使用HBase API创建RDD获取数据,在获取的数据的基础上使用SparkSQL来执行灵活查询。

创建Partition

这里我们要自定义的RDD主要的功能就是获取源数据,所以需要自定义实现Partition类:

private[campaign] class QueryPartition(idx: Int, val start: Long, val stop: Long) extends Partition {
	override def index: Int = idx
}

前面我们说过,主要是依赖行键中的时间戳来进行partition的,所以在自定义的QueryPartition类中保留了两个长整型构造参数start和stop来表示起止时间。剩下的构造参数idx作用是标记分区的索引。

这样,自定义RDD中的getPartitions()方法该如何实现也就很清楚了:

override protected def getPartitions: Array[Partition] = {
	var tmp = unit.startTime
	var i = 0
	val partitions = ArrayBuffer[Partition]()
	while (tmp < unit.stopTime) {
		val stopTime = tmp + TimeUnit.HOURS.toMillis(1L)
		partitions += new QueryPartition(i, tmp, stopTime)
		i = i + 1
		tmp = stopTime
	}
	partitions.toArray
}

在上面代码中的第六行可以看到getPartitions()方法是按每小时一个区间进行partition的。

代码中的unit是一个查询单元,封装了一些必要的查询参数,包括存储数据的表、要查询的标签ID以及起止时间。大致是这样的:

class QueryUnit(val table: String,
                val tagId: String,
                val startTime: Long,
                val stopTime: Long) extends Serializable {
}

注意,QueryUnit这个类需要实现Serializable接口。

查询HBase

因为要实现灵活查询的需求,所以需要将HBase表中符合需求的数据的所有列都取出来。我们可以考虑使用List[Map[String, String]]这样一种结构来暂时保存数据。根据这个需求实现的自定义HBaseClient的代码如下:

import com.typesafe.config.ConfigFactory
import org.apache.hadoop.hbase._
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.filter.Filter
import org.apache.hadoop.hbase.util.Bytes
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer


object HBaseClient {
	private val config = ConfigFactory.load()
	private val conn = getConnection
	  
	/**
	   * 扫描HBase并返回结果
	   */
	def scan(tableName: String, filter: Filter, startRow: String, stopRow: String): List[Map[String, String]] = {
		val s = buildScan(filter, startRow, stopRow)
		val t = conn.getTable(TableName.valueOf(tableName))
		scan(t, s)
	}
	
	/**
	   * 执行扫描
	   */
	private def scan(table: Table, scan: Scan): List[Map[String, String]] = {
		val scanner = table.getScanner(scan)
		val ite = scanner.iterator()
		val result = new ListBuffer[Map[String, String]]
		while (ite.hasNext) {
			val map = new mutable.ListMap[String, String]
			ite.next().listCells().foreach(c => map += readCell(c))
			result += map.toMap
		}
		result.toList
	}
	  
	/**
	   * 读取单元格
	   */
	private def readCell(cell: Cell) = {
		val qualifier = Bytes.toString(CellUtil.cloneQualifier(cell))
		val value = Bytes.toString(CellUtil.cloneValue(cell))
		(qualifier, value)
	}
	  
	/**
	   * 构建Scan实例
	   */
	private def buildScan(filter: Filter, startRow: String, stopRow: String): Scan = {
		val scan = new Scan()
		scan.setMaxVersions
		scan.setCaching(2000)
		scan.setCacheBlocks(false)
		if (null != filter) scan.setFilter(filter)
		if (null != startRow) scan.setStartRow(Bytes.toBytes(startRow))
		if (null != stopRow) scan.setStopRow(Bytes.toBytes(stopRow))
		scan
	}
	  
	/**
	   * 获取连接
	   */
	  private def getConnection: Connection = {
	    val conf = HBaseConfiguration.create()
	    conf.set(HConstants.ZOOKEEPER_QUORUM, config.getString("hbase.zookeeper"))
	    conf.set(HConstants.ZOOKEEPER_ZNODE_PARENT, config.getString("hbase.pnode"))
	    ConnectionFactory.createConnection(conf)
	 }
}

在代码中使用了typesafe的Config类来记录一些信息,比如zookeeper连接信息。

查询HBase这块儿没什么好说的。继续好了。

自定义RDD

前面的两节为自定义实现RDD类做了一些铺垫,包括进行partition的方式,以及一个查询hbase的工具类HBaseClient。实际上我们已经完成实现了自定义RDD的一个抽象方法getPartitions。

自定义RDD需要继承Spark的一个抽象类RDD。

在继承抽象类RDD的同时,需要为它提供两个构造参数,一个是SparkContext实例,一个是父RDD列表。我们要自定义的RDD的是用来获取源数据的,没有父RDD,所以父RDD列表可以直接设置为Nil。

抽象类RDD还有两个抽象方法需要实现,分别是getPartitions()和compute()。getPartitions()方法用来对原始的任务进行分片,可以将原始任务切割成不同的partition,以便进行分布式处理。compute()方法则是实现了对切割出的partition进行处理的逻辑。

getPartitions()方法的实现前面已经提过了,现在看看compute()方法的实现:

  /**
    * 查询HBase
    */
  private def query(part: QueryPartition) = {
    val filter = new PrefixFilter(Bytes.toBytes(unit.tagId))
    val startRow = unit.tagId+part.start
    val stopRow = unit.tagId+part.stop
    val results = HBaseClient.scan(unit.table, filter, startRow, stopRow)
    results.map(e => JSONObject(e).toString())
  }

  /**
    * 执行计算
    */
  override def compute(split: Partition, context: TaskContext): Iterator[String] = {
    val part = split.asInstanceOf[QueryPartition]
    val results = query(part)
    new InterruptibleIterator(context, results.iterator)
  }

可以看到,compute()方法就是在getPartitions()方法创建的时间区间QueryPartition上对HBase中的表进行查询,并将查询出的结果封装成json字符串列表。

编写驱动类
至此,工作已经完成了大半,可以看看驱动类是怎么写的了:

import com.typesafe.config.ConfigFactory
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object HBaseQueryDriver {
	def main(args: Array[String]): Unit = {
		SparkUtils.runJob(start, args, SparkUtils.initSparkContext("CustomQueryJob"), "CustomQueryJob")
	}
	
	def start(args: Array[String], sc: SparkContext) = {
		val Array(topic, tagId, startTime, stopTime, sql) = args
		new HBaseQueryDriver(topic, tagId, startTime.toLong, stopTime.toLong, sql).query(sc)
	}
}

class HBaseQueryDriver(table: String, tagId: String, startTime: Long, stopTime: Long, sql: String) {
	private val config = ConfigFactory.load()
	private val outPath = config.getString("out.path")
	private val tmpPath = config.getString("tmp.path")
	def query(sc: SparkContext) = {
		val unit = new QueryUnit(table, tagId, startTime, stopTime)
		val dataRDD = new QueryRDD(sc, unit)
		val sqlContext = new SQLContext(sc)
		sqlContext.read.json(dataRDD).registerTempTable(table)
		Hdfs.delete(outPath)
		Hdfs.delete(tmpPath)
		sqlContext.sql(sql).map(row => row.mkString(",")).saveAsTextFile(tmpPath)
		sc.textFile(tmpPath).coalesce(1, true).saveAsTextFile(outPath, classOf[GzipCodec])
	}
}

驱动类的任务是在创建的QueryRDD上使用SparkSQL执行查询,并将查询结果保存到HDFS上。

代码中的SparkUtil只是将经常使用的初始化SparkContext以及执行Spark任务的行为封装了一下,是这样实现的:

object SparkUtils extends Logging {
  /**
    * 执行作业
    */
  def runJob(job: (Array[String], SparkContext) => Unit,
             args: Array[String],
             context: SparkContext,
             errorMsg: String = "任务执行失败"): Unit = {
    val startTime = System.currentTimeMillis()
    try {
      job(args, context)
    } catch {
      case ex: Throwable => ex.printStackTrace(); throw ex
    } finally {
      log.info("Cost Time: %s".format(System.currentTimeMillis() - startTime))
    }
  }
  /**
    * 初始化spark上下文
    */
  def initSparkContext(appName: String, master: String = null): SparkContext = {
    val sparkConf = new SparkConf().setAppName(appName)
    sparkConf.set("spark.files.userClassPathFirst", "true")
    if (null == sparkConf.get("spark.master", null) && null == master)
      sparkConf.set("spark.master", "local[*]")
    if (null != master) sparkConf.set("spark.master", master)
    new SparkContext(sparkConf)
  }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值