在Spark SQL中使用自定义聚集函数
这篇博客是笔者在进行创新实训课程项目时所做工作的回顾。对于该课程项目所有的工作记录,读者可以参阅下面的链接。
参考资料
感谢下面的资料对我这篇博客撰写过程的帮助:
- 感谢B站的这门网课让我迅速了解Spark SQL:https://www.bilibili.com/video/BV1QE411p7T4?p=1
- 感兴趣的读者也可以查看官方文档:https://spark.apache.org/docs/latest/sql-programming-guide.html
- 从这个讨论中了解Spark SQL支持的语法范围:https://www.zhihu.com/question/34569764
问题背景
在学习这个项目时,我对 logvision 项目下的 spark/src/main/scala/streaming.scala 文件有过一点研究。我在想,streaming.scala 文件里,作者通过 url 是否异常,来将一个访问请求界定为善意或者恶意。但是实际上,我们可以考虑更多的因素。比如说,我编写了一个爬虫,只要程序得当,爬虫是可以一直发送正常的 url 请求的,但与此同时爬虫高强度的爬取也会给目标网站的服务器带来负担。这时,我们就可以说,虽然爬虫的请求url都是正常的,但是由于它的访问频率过高,我们仍然可以把爬虫的请求界定为“恶意请求”。
在这篇博客中,我想要略微改造一下streaming.scala文件,在判断 url 是否正常的基础上,再判断一下对应 ip 地址的200状态比。在这篇博文中,我这样定义200状态比:
一个ip地址发送请求的200状态比,指的是这个ip地址发出的所有请求中,返回的状态码为200的请求所占的比例。这个比例越高,代表这个ip地址发出的请求大都得到了有效响应,这个ip地址发出的请求就越偏向正常。
在改造后的程序中,一个请求是正常的,当且仅当它的 url 请求是正常的,并且它对应的 ip 地址的200状态比高于一个阈值(比如0.8)。否则,这个请求就被判断为异常的,如下图所示。我们这篇博客,就是要按照这个标准把日志中的请求分成正常、异常两类。在解决问题的过程中,我们就需要用到自定义聚集函数。
在Spark中写SQL
接下来,我们讨论怎么样在scala文件中,使用Spark的数据库查询功能。
在Spark中,分布式环境下的数据查询可以说是Spark进行大数据处理的基石。Spark的机器学习库等工具,大都需要在查询出来的DataFrame上进行操作。所以,我们接下来就学习一下怎么使用Spark进行数据查询,尤其是采用SQL语句进行查询。
如果你查看官方文档时感到有点头痛,我向你推荐B站上的一门网课,这位老师讲的是真的细致,并且现场编程,很快就可以让你理解大致的编程思路。这篇文章就是大量地参考这门网课写出来的。
在Spark中使用SQL语句来查询,一般而言有两种语法风格。一种是使用Spark中原生的方法来进行查询,一种是使用sql()函数来内嵌SQL字符串进行查询。比如说,对于这个视频中提到的例子,对于两个DataFrame(你可以理解为两张表),如下所示:
第一个DataFrame:df1
id | name | nation |
---|---|---|
1 | LaoZhao | china |
2 | LaoDuan | usa |
第二个DataFrame:df2
ename | cname |
---|---|
china | 中国 |
usa | 美国 |
如果要求我们使用SQL的 inner join 语句对这两张表进行联合查询,那么有SQL基础的读者一定会认为这挺容易的。下面是参考的SQL语句:
SELECT id, name, nation, cname FROM df1 INNER JOIN df2 ON nation = ename;
下面我们给出Spark中的两个风格的查询语句。是不是觉得很亲切?
风格1:Spark中原生的方法
val result: DataFrame = df1.join(df2, $"nation" === $"ename")
风格2:使用sql()函数来内嵌SQL字符串
df1.createTempView("v_users")
df2.createTempView("v_nations")
val result: DataFrame = spark.sql("SELECT id, name, nation, cname FROM v_users INNER JOIN v_nations ON nation = ename")
其中,spark是一个SparkSession类型的变量,表示一次查询会话。我们发现,风格2的方案对比较了解SQL的读者而言,更容易上手一些,因为几乎不用学习新的函数。下面的篇幅里我们将以风格2为主。
Spark支持的SQL语法十分丰富,这个讨论里列举了其中支持的SQL语法范围。Spark是默认支持GROUP BY和常用的聚集函数,比如count(),countDistinct(),avg(),max(),min() 等。然而,Spark还支持用户自定义函数,这极大地丰富了用户的查询自由度。
UDF函数与UDAF函数
UDF(User Defined Function)函数与UDAF(User Defined Aggregate Function)函数是Spark在常规的聚集函数之外支持的特性。具体来说,UDF函数是允许用户自定义“一对一”函数——对每一条输入的记录都返回一个结果;UDAF函数是允许用户自定义“多对一”函数——对输入的多条记录返回一个结果,也就是用户自定义的聚集函数。
在这个视频里,老师提到了如何自定义一个UDF函数。我们把其中的部分代码摘录下来,用来明确UDF函数自定义的姿势:
// 自定义一个UDF函数,并注册。这个函数输入Long类型的ipNum,返回province(省份名)
spark.udf.register("ip2Province", (ipNum: Long) => {
// 下面的函数细节不必关心,我们知道如何自定义UDF函数就好
val ipRulesInExecutor: Array[(Long, Long, String)] = broadcastRef.value
val index = MyUtils.binarySearch(ipRulesInExecutor, ipNum)
var province = "未知"
if(index != -1) {
province = ipRulesInExecutor(index)._3
}
province
})
// 执行SQL,spark是SparkSession类型,使用刚定义好的ip2Province函数
spark.sql("SELECT ip2Province(ip_num) province, COUNT(*) counts FROM v_log GROUP BY province ORDER BY counts DESC")
而UDAF函数就不太一样,它也需要定义并注册一个自定义函数,但是用户要自定义UDAF函数,就必须先要实现一个UserDefinedAggregateFunction接口。在下一节中,我们结合博客中提出的200请求比统计的问题,来讲述一下如何在Spark中自定义一个UDAF函数。
在Spark中自定义UDAF函数
我们强烈建议读者先观看这个视频,了解一下UDAF函数到底应该怎么写,这对读者理解本节内容有很大的帮助。
首先,让我们假定一个叫 nmlpct(status_code) 的函数,能够计算出状态码 status_code 中200的占比情况,那么我们要计算每个 ip 地址的200状态比,只需要下面的SQL语句就可以了:
SELECT host, nmlpct(status_code) normal_percent FROM ori_result GROUP BY host
下面的语句,就是在筛选出所有200状态比大于0.8的请求条目:
val minStandard = 0.8
sparkSession
.sql("SELECT host, nmlpct(status_code) normal_percent FROM ori_result GROUP BY host HAVING nmlpct(status_code) > " + minStandard)
.createOrReplaceTempView("nml_pct_result")
这样,我们就可以兼顾200状态比的要求和URL字符串正常的要求,过滤出善意与恶意的请求了:
// 正常请求结果
val goodResult = sparkSession.sql("SELECT * FROM ori_result WHERE prediction = 0.0 AND host IN (SELECT host FROM nml_pct_result)")
// 异常请求结果
val badResult = sparkSession.sql("SELECT * FROM ori_result WHERE prediction = 1.0 OR host NOT IN (SELECT host FROM nml_pct_result)")
那么,我们怎么自定义nmlpct这个函数呢?我们按照前面讲的这个视频的做法,先定义一个 NormalPercent 函数,它实现了 UserDefinedAggregateFunction 接口:
class NormalPercent extends UserDefinedAggregateFunction {
// 聚合函数的输入数据结构
override def inputSchema: StructType = StructType(List(
StructField("stcode", StringType)
))
// 缓存区数据结构
override def bufferSchema: StructType = StructType(List(
StructField("nml_count", LongType),
StructField("ttl_count", LongType)
))
// 聚合函数返回值数据结构
override def dataType: DataType = DoubleType
// 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
override def deterministic: Boolean = true
// 初始化缓冲区
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
// 给聚合函数传入一条新数据进行处理
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getLong(1) + input.getString(0) match {
case "200" => 1L
case _ => 0L
}
buffer(1) = buffer.getLong(1) + 1L
}
// 合并聚合函数缓冲区
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
override def evaluate(buffer: Row): Double = {
var normal_pct = 1
if(buffer.getLong(1) != 0){
buffer.getLong(0).toDouble / buffer.getLong(1).toDouble
}
normal_pct
}
}
然后,我们用上面的 NormalPercent 类来自定义 nmlpct 函数:
val result = learning.applying("/home/logv/IDModel", rdd.toDF())
.select("host", "rfc931", "username", "timestamp", "req_method", "url", "protocol", "status_code", "bytes", "probability", "prediction")
// 统计200状态所占比例
result.createOrReplaceTempView("ori_result")
val nmlpct = new NormalPercent
sparkSession.udf.register("nmlpct", nmlpct)
这样,我们就成功地自定义了nmlpct这个UDAF函数啦。把上面的代码综合一下,就解决了博客中提到的200请求比的问题。