在Spark上通过自定义RDD访问HBase

这里介绍一个在Spark上使用自定义RDD获取HBase数据的方案。

这个方案的基础是我们的HBase表的行键设计。行键设计大概是这样子的:标签ID+时间戳+随机码。平时的需求主要是导出指定标签在某个时间范围内的全部记录。根据需求和行键设计确定下实现的大方向:使用行键中的时间戳进行partition并界定startRow和stopRow来缩小查询范围,使用HBase API创建RDD获取数据,在获取的数据的基础上使用SparkSQL来执行灵活查询。
创建Partition

这里我们要自定义的RDD主要的功能就是获取源数据,所以需要自定义实现Partition类:

private[campaign] class QueryPartition(idx: Int, val start: Long, val stop: Long) extends Partition {
override def index: Int = idx
}
1
2
3

private[campaign] class QueryPartition(idx: Int, val start: Long, val stop: Long) extends Partition {
override def index: Int = idx
}

前面我们说过,主要是依赖行键中的时间戳来进行partition的,所以在自定义的QueryPartition类中保留了两个长整型构造参数start和stop来表示起止时间。剩下的构造参数idx作用是标记分区的索引。

这样,自定义RDD中的getPartitions()方法该如何实现也就很清楚了:

override protected def getPartitions: Array[Partition] = {
var tmp = unit.startTime
var i = 0
val partitions = ArrayBufferPartition
while (tmp < unit.stopTime) {
val stopTime = tmp + TimeUnit.HOURS.toMillis(1L)
partitions += new QueryPartition(i, tmp, stopTime)
i = i + 1
tmp = stopTime
}
partitions.toArray
}
1
2
3
4
5
6
7
8
9
10
11
12

override protected def getPartitions: Array[Partition] = {
var tmp = unit.startTime
var i = 0
val partitions = ArrayBufferPartition
while (tmp < unit.stopTime) {
val stopTime = tmp + TimeUnit.HOURS.toMillis(1L)
partitions += new QueryPartition(i, tmp, stopTime)
i = i + 1
tmp = stopTime
}
partitions.toArray
}

在上面代码中的第六行可以看到getPartitions()方法是按每小时一个区间进行partition的。

代码中的unit是一个查询单元,封装了一些必要的查询参数,包括存储数据的表、要查询的标签ID以及起止时间。大致是这样的:

class QueryUnit(val table: String,
val tagId: String,
val startTime: Long,
val stopTime: Long) extends Serializable {
}
1
2
3
4
5

class QueryUnit(val table: String,
val tagId: String,
val startTime: Long,
val stopTime: Long) extends Serializable {
}

注意,QueryUnit这个类需要实现Serializable接口。
查询HBase

因为要实现灵活查询的需求,所以需要将HBase表中符合需求的数据的所有列都取出来。我们可以考虑使用List[Map[String, String]]这样一种结构来暂时保存数据。根据这个需求实现的自定义HBaseClient的代码如下:

import com.typesafe.config.ConfigFactory
import org.apache.hadoop.hbase._
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.filter.Filter
import org.apache.hadoop.hbase.util.Bytes
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object HBaseClient {
private val config = ConfigFactory.load()
private val conn = getConnection

/**
   * 扫描HBase并返回结果
   */
def scan(tableName: String, filter: Filter, startRow: String, stopRow: String): List[Map[String, String]] = {
	val s = buildScan(filter, startRow, stopRow)
	val t = conn.getTable(TableName.valueOf(tableName))
	scan(t, s)
}

/**
   * 执行扫描
   */
private def scan(table: Table, scan: Scan): List[Map[String, String]] = {
	val scanner = table.getScanner(scan)
	val ite = scanner.iterator()
	val result = new ListBuffer[Map[String, String]]
	while (ite.hasNext) {
		val map = new mutable.ListMap[String, String]
		ite.next().listCells().foreach(c => map += readCell(c))
		result += map.toMap
	}
	result.toList
}
  
/**
   * 读取单元格
   */
private def readCell(cell: Cell) = {
	val qualifier = Bytes.toString(CellUtil.cloneQualifier(cell))
	val value = Bytes.toString(CellUtil.cloneValue(cell))
	(qualifier, value)
}
  
/**
   * 构建Scan实例
   */
private def buildScan(filter: Filter, startRow: String, stopRow: String): Scan = {
	val scan = new Scan()
	scan.setMaxVersions
	scan.setCaching(2000)
	scan.setCacheBlocks(false)
	if (null != filter) scan.setFilter(filter)
	if (null != startRow) scan.setStartRow(Bytes.toBytes(startRow))
	if (null != stopRow) scan.setStartRow(Bytes.toBytes(stopRow))
	scan
}
  
/**
   * 获取连接
   */
  private def getConnection: Connection = {
    val conf = HBaseConfiguration.create()
    conf.set(HConstants.ZOOKEEPER_QUORUM, config.getString("hbase.zookeeper"))
    conf.set(HConstants.ZOOKEEPER_ZNODE_PARENT, config.getString("hbase.pnode"))
    ConnectionFactory.createConnection(conf)
 }

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

import com.typesafe.config.ConfigFactory
import org.apache.hadoop.hbase._
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.filter.Filter
import org.apache.hadoop.hbase.util.Bytes
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object HBaseClient {
private val config = ConfigFactory.load()
private val conn = getConnection

/**
   * 扫描HBase并返回结果
   */
def scan(tableName: String, filter: Filter, startRow: String, stopRow: String): List[Map[String, String]] = {
	val s = buildScan(filter, startRow, stopRow)
	val t = conn.getTable(TableName.valueOf(tableName))
	scan(t, s)
}

/**
   * 执行扫描
   */
private def scan(table: Table, scan: Scan): List[Map[String, String]] = {
	val scanner = table.getScanner(scan)
	val ite = scanner.iterator()
	val result = new ListBuffer[Map[String, String]]
	while (ite.hasNext) {
		val map = new mutable.ListMap[String, String]
		ite.next().listCells().foreach(c => map += readCell(c))
		result += map.toMap
	}
	result.toList
}
  
/**
   * 读取单元格
   */
private def readCell(cell: Cell) = {
	val qualifier = Bytes.toString(CellUtil.cloneQualifier(cell))
	val value = Bytes.toString(CellUtil.cloneValue(cell))
	(qualifier, value)
}
  
/**
   * 构建Scan实例
   */
private def buildScan(filter: Filter, startRow: String, stopRow: String): Scan = {
	val scan = new Scan()
	scan.setMaxVersions
	scan.setCaching(2000)
	scan.setCacheBlocks(false)
	if (null != filter) scan.setFilter(filter)
	if (null != startRow) scan.setStartRow(Bytes.toBytes(startRow))
	if (null != stopRow) scan.setStartRow(Bytes.toBytes(stopRow))
	scan
}
  
/**
   * 获取连接
   */
  private def getConnection: Connection = {
    val conf = HBaseConfiguration.create()
    conf.set(HConstants.ZOOKEEPER_QUORUM, config.getString("hbase.zookeeper"))
    conf.set(HConstants.ZOOKEEPER_ZNODE_PARENT, config.getString("hbase.pnode"))
    ConnectionFactory.createConnection(conf)
 }

}

在代码中使用了typesafe的Config类来记录一些信息,比如zookeeper连接信息。

查询HBase这块儿没什么好说的。继续好了。
自定义RDD

前面的两节为自定义实现RDD类做了一些铺垫,包括进行partition的方式,以及一个查询hbase的工具类HBaseClient。实际上我们已经完成实现了自定义RDD的一个抽象方法getPartitions。

自定义RDD需要继承Spark的一个抽象类RDD。

在继承抽象类RDD的同时,需要为它提供两个构造参数,一个是SparkContext实例,一个是父RDD列表。我们要自定义的RDD的是用来获取源数据的,没有父RDD,所以父RDD列表可以直接设置为Nil。

抽象类RDD还有两个抽象方法需要实现,分别是getPartitions()和compute()。getPartitions()方法用来对原始的任务进行分片,可以将原始任务切割成不同的partition,以便进行分布式处理。compute()方法则是实现了对切割出的partition进行处理的逻辑。

getPartitions()方法的实现前面已经提过了,现在看看compute()方法的实现:

/**
* 查询HBase
*/
private def query(part: QueryPartition) = {
val filter = new PrefixFilter(Bytes.toBytes(unit.tagId))
val startRow = unit.tagId+part.start
val stopRow = unit.tagId+part.stop
val results = HBaseClient.scan(unit.table, filter, startRow, stopRow)
results.map(e => JSONObject(e).toString())
}

/**
* 执行计算
*/
override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val part = split.asInstanceOf[QueryPartition]
val results = query(part)
new InterruptibleIterator(context, results.iterator)
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

/**
* 查询HBase
*/
private def query(part: QueryPartition) = {
val filter = new PrefixFilter(Bytes.toBytes(unit.tagId))
val startRow = unit.tagId+part.start
val stopRow = unit.tagId+part.stop
val results = HBaseClient.scan(unit.table, filter, startRow, stopRow)
results.map(e => JSONObject(e).toString())
}

/**
* 执行计算
*/
override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val part = split.asInstanceOf[QueryPartition]
val results = query(part)
new InterruptibleIterator(context, results.iterator)
}

可以看到,compute()方法就是在getPartitions()方法创建的时间区间QueryPartition上对HBase中的表进行查询,并将查询出的结果封装成json字符串列表。
编写驱动类

至此,工作已经完成了大半,可以看看驱动类是怎么写的了:

import com.typesafe.config.ConfigFactory
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object HBaseQueryDriver {
def main(args: Array[String]): Unit = {
SparkUtils.runJob(start, args, SparkUtils.initSparkContext(“CustomQueryJob”), “CustomQueryJob”)
}

def start(args: Array[String], sc: SparkContext) = {
	val Array(topic, tagId, startTime, stopTime, sql) = args
	new HBaseQueryDriver(topic, tagId, startTime.toLong, stopTime.toLong, sql).query(sc)
}

}

class HBaseQueryDriver(table: String, tagId: String, startTime: Long, stopTime: Long, sql: String) {
private val config = ConfigFactory.load()
private val outPath = config.getString(“out.path”)
private val tmpPath = config.getString(“tmp.path”)
def query(sc: SparkContext) = {
val unit = new QueryUnit(table, tagId, startTime, stopTime)
val dataRDD = new QueryRDD(sc, unit)
val sqlContext = new SQLContext(sc)
sqlContext.read.json(dataRDD).registerTempTable(table)
Hdfs.delete(outPath)
Hdfs.delete(tmpPath)
sqlContext.sql(sql).map(row => row.mkString(",")).saveAsTextFile(tmpPath)
sc.textFile(tmpPath).coalesce(1, true).saveAsTextFile(outPath, classOf[GzipCodec])
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

import com.typesafe.config.ConfigFactory
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object HBaseQueryDriver {
def main(args: Array[String]): Unit = {
SparkUtils.runJob(start, args, SparkUtils.initSparkContext(“CustomQueryJob”), “CustomQueryJob”)
}

def start(args: Array[String], sc: SparkContext) = {
	val Array(topic, tagId, startTime, stopTime, sql) = args
	new HBaseQueryDriver(topic, tagId, startTime.toLong, stopTime.toLong, sql).query(sc)
}

}

class HBaseQueryDriver(table: String, tagId: String, startTime: Long, stopTime: Long, sql: String) {
private val config = ConfigFactory.load()
private val outPath = config.getString(“out.path”)
private val tmpPath = config.getString(“tmp.path”)
def query(sc: SparkContext) = {
val unit = new QueryUnit(table, tagId, startTime, stopTime)
val dataRDD = new QueryRDD(sc, unit)
val sqlContext = new SQLContext(sc)
sqlContext.read.json(dataRDD).registerTempTable(table)
Hdfs.delete(outPath)
Hdfs.delete(tmpPath)
sqlContext.sql(sql).map(row => row.mkString(",")).saveAsTextFile(tmpPath)
sc.textFile(tmpPath).coalesce(1, true).saveAsTextFile(outPath, classOf[GzipCodec])
}
}

驱动类的任务是在创建的QueryRDD上使用SparkSQL执行查询,并将查询结果保存到HDFS上。

代码中的SparkUtil只是将经常使用的初始化SparkContext以及执行Spark任务的行为封装了一下,是这样实现的:

object SparkUtils extends Logging {
/**
* 执行作业
/
def runJob(job: (Array[String], SparkContext) => Unit,
args: Array[String],
context: SparkContext,
errorMsg: String = “任务执行失败”): Unit = {
val startTime = System.currentTimeMillis()
try {
job(args, context)
} catch {
case ex: Throwable => ex.printStackTrace(); throw ex
} finally {
log.info(“Cost Time: %s”.format(System.currentTimeMillis() - startTime))
}
}
/
*
* 初始化spark上下文
/
def initSparkContext(appName: String, master: String = null): SparkContext = {
val sparkConf = new SparkConf().setAppName(appName)
sparkConf.set(“spark.files.userClassPathFirst”, “true”)
if (null == sparkConf.get(“spark.master”, null) && null == master)
sparkConf.set(“spark.master”, "local[
]")
if (null != master) sparkConf.set(“spark.master”, master)
new SparkContext(sparkConf)
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

object SparkUtils extends Logging {
/**
* 执行作业
/
def runJob(job: (Array[String], SparkContext) => Unit,
args: Array[String],
context: SparkContext,
errorMsg: String = “任务执行失败”): Unit = {
val startTime = System.currentTimeMillis()
try {
job(args, context)
} catch {
case ex: Throwable => ex.printStackTrace(); throw ex
} finally {
log.info(“Cost Time: %s”.format(System.currentTimeMillis() - startTime))
}
}
/
*
* 初始化spark上下文
/
def initSparkContext(appName: String, master: String = null): SparkContext = {
val sparkConf = new SparkConf().setAppName(appName)
sparkConf.set(“spark.files.userClassPathFirst”, “true”)
if (null == sparkConf.get(“spark.master”, null) && null == master)
sparkConf.set(“spark.master”, "local[
]")
if (null != master) sparkConf.set(“spark.master”, master)
new SparkContext(sparkConf)
}
}

好了,就这样!!
–重写 RDD
class QueryRDD(sc: SparkContext, tableName: String, startRow: String, endRow: String, splitKeys: Array[String]) extends RDD[Map[String,String]](sc, Nil)
{

#重写该方法用于计算每一个 partition
override def compute(split: Partition, context: TaskContext): Iterator[Map[String,String]] =
{
val part = split.asInstanceOf[QueryPartition]
val results = query(part)
new InterruptibleIterator(context, results.iterator)
}

#重写该方法用于获取 partition
override protected def getPartitions: Array[Partition] =
{
val partitions = ArrayBufferPartition
for (splitKey <- splitKeys)
{
partitions += new QueryPartition(splitKey)
}
partitions.toArray
}

private def query(partition: QueryPartition) =
{
val splitKey = partition.split
val filter = null #该参数可以不为 null,即可在 scan 的同时进行 filter
val start = splitKey + startRow
val end = splitKey + endRow
HBaseClient.scan(tableName, filter, start, end)
}
}

#实现自己的 partition
class QueryPartition(splitKey: String) extends Partition
{
def split: String = splitKey

override def index: Int = splitKey.substring(0, 3).toInt

override def hashCode(): Int = index
}

以上是重写 RDD,hbase 的具体 scan 操作,在我上面的链接里可以找到,我照搬了过来.但是要注意他的 58 行,要把 startRow 改成 stopRow,不然的话其他代码写得再好都白费啦

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值