Counting UV with custom UDF/UDAF functions in Spark SQL (using bitmaps)
In practice, UV is usually computed with count(distinct userId), but that approach is inefficient: distinct counts cannot be summed across groups, so if you aggregate along several dimensions and later want to roll one of them up, you have to go back to the raw layer and count again, which is slow on large datasets. Bitmaps avoid this because they are mergeable: OR-ing the bitmaps of the finest-grained aggregates gives an exact deduplicated count at any coarser grain. This requires a custom aggregate function that builds a bitmap per group and serializes it to a byte array.
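The examples below rely on the RoaringBitmap library for compressed bitmaps. A minimal sbt dependency sketch (the version shown is an assumption; pick whatever current release fits your build):

libraryDependencies += "org.roaringbitmap" % "RoaringBitmap" % "0.9.25"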
1) First-pass aggregation
package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:55
 *
 * Aggregates user IDs into a RoaringBitmap. The buffer and the output are
 * both the serialized form of the bitmap (a byte array), so the result can
 * be stored and merged again later.
 */
class BitmapGenUDAF extends Aggregator[Int, Array[Byte], Array[Byte]] {

  override def zero: Array[Byte] = {
    // Construct an empty bitmap
    val bm: RoaringBitmap = RoaringBitmap.bitmapOf()
    // Serialize the bitmap into a byte array
    BitmapUtil.serBitmap(bm)
  }

  override def reduce(b: Array[Byte], a: Int): Array[Byte] = {
    // Deserialize the buffer back into a bitmap, then set the bit for this ID
    val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(b)
    bitmap.add(a)
    BitmapUtil.serBitmap(bitmap)
  }

  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = {
    // OR the two partial bitmaps together (or() mutates bitmap1 in place)
    val bitmap1: RoaringBitmap = BitmapUtil.deSerBitmap(b1)
    val bitmap2: RoaringBitmap = BitmapUtil.deSerBitmap(b2)
    bitmap1.or(bitmap2)
    BitmapUtil.serBitmap(bitmap1)
  }

  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY
}
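Note that RoaringBitmap stores 32-bit integers, which is why the aggregator's input type is Int even though the id column in the test below is a LongType (Spark inserts a cast when the function is applied). If your IDs can genuinely exceed the Int range, the library's 64-bit structure is an option. A minimal sketch of its core calls (an alternative, not used in the rest of this post):

import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.roaringbitmap.longlong.Roaring64NavigableMap

val bm64 = new Roaring64NavigableMap()
bm64.addLong(10000000000L)                  // an ID beyond the Int range
val out = new ByteArrayOutputStream()
bm64.serialize(new DataOutputStream(out))   // same DataOutput-based API
println(bm64.getLongCardinality)            // prints 1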
package org.shydow.UDF

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:45
 */
object BitmapUtil {

  /**
   * Serialize a bitmap into a byte array
   */
  def serBitmap(bm: RoaringBitmap): Array[Byte] = {
    val stream = new ByteArrayOutputStream()
    val dataOutput = new DataOutputStream(stream)
    bm.serialize(dataOutput)
    stream.toByteArray
  }

  /**
   * Deserialize a byte array back into a bitmap
   */
  def deSerBitmap(bytes: Array[Byte]): RoaringBitmap = {
    val bm: RoaringBitmap = RoaringBitmap.bitmapOf()
    val stream = new ByteArrayInputStream(bytes)
    val inputStream = new DataInputStream(stream)
    bm.deserialize(inputStream)
    bm
  }
}
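A quick round-trip check of these helpers (a hypothetical sanity test, not part of the original project):

val original: RoaringBitmap = RoaringBitmap.bitmapOf(1, 5, 42)
val restored: RoaringBitmap = BitmapUtil.deSerBitmap(BitmapUtil.serBitmap(original))
assert(restored.getCardinality == 3)
assert(restored.contains(42))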
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:25
 */
object TestBehaviorAnalysis {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("eventType", StringType),
      StructField("code", StringType),
      StructField("timestamp", LongType))
    )
    val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
    frame.createOrReplaceTempView("order_log")

    /**
     * Compute UV with count(distinct) for comparison
     */
    spark.sql(
      s"""
         |select
         |  eventType,
         |  count(1) as pv,
         |  count(distinct id) as uv
         |from order_log
         |group by eventType
         |""".stripMargin).show()

    /**
     * Compute UV with the custom UDAF
     */
    import org.apache.spark.sql.functions.udaf
    // gen_bitmap returns a byte array; to get the actual cardinality we also need a scalar UDF
    spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF))
    def card(byteArray: Array[Byte]): Int = {
      val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
      bitmap.getCardinality
    }
    spark.udf.register("get_card", card _)
    spark.sql(
      s"""
         |select
         |  eventType,
         |  count(1) as pv,
         |  gen_bitmap(id) as uv_arr,
         |  get_card(gen_bitmap(id)) as uv
         |from order_log
         |group by eventType
         |""".stripMargin).show()

    spark.close()
  }
}
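The same aggregation also works through the DataFrame API, since udaf(...) returns an ordinary UserDefinedFunction that can be applied as a column. A minimal sketch under that assumption, reusing the names above:

import org.apache.spark.sql.functions.{col, count, lit, udaf}

val genBitmap = udaf(new BitmapGenUDAF)
frame.groupBy(col("eventType"))
  .agg(
    count(lit(1)).as("pv"),
    // cast explicitly: the aggregator's input type is Int
    genBitmap(col("id").cast("int")).as("uv_arr"))
  .show()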
2) Roll-up aggregation
package org.shydow.UDF

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/14 8:36
 *
 * Merges already-serialized bitmaps by OR-ing them. Because the input, the
 * buffer and the output all share the same byte-array representation,
 * reduce and merge are identical.
 */
class BitmapOrMergeUDAF extends Aggregator[Array[Byte], Array[Byte], Array[Byte]] {

  override def zero: Array[Byte] = {
    val bitmap: RoaringBitmap = RoaringBitmap.bitmapOf()
    BitmapUtil.serBitmap(bitmap)
  }

  override def reduce(b: Array[Byte], a: Array[Byte]): Array[Byte] = {
    val bitmap1: RoaringBitmap = BitmapUtil.deSerBitmap(b)
    val bitmap2: RoaringBitmap = BitmapUtil.deSerBitmap(a)
    bitmap1.or(bitmap2)
    BitmapUtil.serBitmap(bitmap1)
  }

  override def merge(b1: Array[Byte], b2: Array[Byte]): Array[Byte] = {
    val bitmap1: RoaringBitmap = BitmapUtil.deSerBitmap(b1)
    val bitmap2: RoaringBitmap = BitmapUtil.deSerBitmap(b2)
    bitmap1.or(bitmap2)
    BitmapUtil.serBitmap(bitmap1)
  }

  override def finish(reduction: Array[Byte]): Array[Byte] = reduction

  override def bufferEncoder: Encoder[Array[Byte]] = Encoders.BINARY

  override def outputEncoder: Encoder[Array[Byte]] = Encoders.BINARY
}
package org.shydow.UDF

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.roaringbitmap.RoaringBitmap

/**
 * @author shydow
 * @date 2021/12/13 22:25
 */
object TestBehaviorAnalysis {

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("analysis")
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("eventType", StringType),
      StructField("code", StringType),
      StructField("timestamp", LongType))
    )
    val frame: DataFrame = spark.read.schema(schema).csv("data/OrderLog.csv")
    frame.createOrReplaceTempView("order_log")

    /**
     * Compute UV with count(distinct) for comparison
     */
    spark.sql(
      s"""
         |select
         |  eventType,
         |  code,
         |  count(1) as pv,
         |  count(distinct id) as uv
         |from order_log
         |where code is not null
         |group by eventType, code
         |""".stripMargin).show()

    /**
     * Finest-grained aggregation with the custom UDAF
     */
    import org.apache.spark.sql.functions.udaf
    // gen_bitmap returns a byte array; to get the actual cardinality we also need a scalar UDF
    spark.udf.register("gen_bitmap", udaf(new BitmapGenUDAF))
    def card(byteArray: Array[Byte]): Int = {
      val bitmap: RoaringBitmap = BitmapUtil.deSerBitmap(byteArray)
      bitmap.getCardinality
    }
    spark.udf.register("get_card", card _)
    val res: DataFrame = spark.sql(
      s"""
         |select
         |  eventType,
         |  code,
         |  count(1) as pv,
         |  gen_bitmap(id) as uv_arr,
         |  get_card(gen_bitmap(id)) as uv
         |from order_log
         |where code is not null
         |group by eventType, code
         |""".stripMargin)
    res.createTempView("dws_stat")

    // Roll code up: merge the per-(eventType, code) bitmaps instead of rescanning the raw data
    spark.udf.register("bitmapOr", udaf(new BitmapOrMergeUDAF))
    spark.sql(
      s"""
         |select
         |  eventType,
         |  sum(pv) as total_pv,
         |  bitmapOr(uv_arr) as uv_arr,
         |  get_card(bitmapOr(uv_arr)) as total_uv
         |from dws_stat
         |group by eventType
         |""".stripMargin).show()

    spark.close()
  }
}
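In a real pipeline the dws_stat result, including the serialized uv_arr column, would be persisted (Parquet handles BINARY columns), so the roll-up can run as a separate job without touching order_log again. A minimal sketch, assuming a hypothetical output path:

// Persist the finest-grained aggregates; "data/dws_stat" is a hypothetical path
res.write.mode("overwrite").parquet("data/dws_stat")

// A later job only needs the small DWS table to roll up further dimensions
spark.read.parquet("data/dws_stat").createOrReplaceTempView("dws_stat")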