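Introduction
Over the past couple of days I have been reading up on Structured Streaming, and I decided to practice by writing a "grouped top N" query with a window function. It failed right away with: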
Non-time-based windows are not supported on streaming DataFrames/Datasets;;
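A quick search turned up the explanation that streaming Datasets do not support window functions (WTF!). I then dug into the source code, and the exception is thrown here: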
object UnsupportedOperationChecker {
  ...
  plan.foreachUp { implicit subPlan =>
    // Operations that cannot exists anywhere in a streaming plan
    subPlan match {
      ...
      // In other words, a window spec is simply not allowed here at all...
      case Window(_, _, _, child) if child.isStreaming =>
        throwError("Non-time-based windows are not supported on streaming DataFrames/Datasets")
      ...
    }
  }
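There is no way around it, so I had to implement it myself...
To get window-like behavior in Structured Streaming you need a stateful operation (here I use flatMapGroupsWithState). This operator is roughly the counterpart of updateStateByKey in Spark Streaming: the aggregated results so far are kept in the state, every time a new batch of data arrives it is merged into the state, and then the top N is taken from the state for each group. Enough introduction, here is the code (Talk is cheap, show me the CODE).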
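To make the "keep the aggregates in state and merge each new batch into it" idea concrete, here is a minimal plain-Scala sketch with no Spark involved. It simulates what happens per key across micro-batches. The names `StatefulCountSketch`, `Event`, `Store` and `processBatch` are hypothetical and exist only for this illustration. It also shows one detail of flatMapGroupsWithState with `GroupStateTimeout.NoTimeout()`: the function only runs for keys that received data in the current batch, so a province without new events keeps its state but emits nothing.

```scala
// Plain-Scala simulation of per-key state surviving across micro-batches (no Spark).
object StatefulCountSketch {
  // Hypothetical simplified event carrying only the fields the top-N needs.
  final case class Event(province: String, city: String)

  // State kept between batches: province -> (city -> cumulative count).
  type Store = Map[String, Map[String, Int]]

  // Handle one micro-batch: group the new events by province, merge each group's
  // counts into that province's previous state, and return the full updated store
  // plus the merged counts of only those provinces that received data in this batch.
  def processBatch(store: Store, batch: Seq[Event]): (Store, Store) = {
    val updated: Store = batch.groupBy(_.province).map { case (province, events) =>
      val old = store.getOrElse(province, Map.empty[String, Int])
      val merged = events.foldLeft(old) { (m, e) => m.updated(e.city, m.getOrElse(e.city, 0) + 1) }
      province -> merged
    }
    (store ++ updated, updated)
  }

  def main(args: Array[String]): Unit = {
    val batch1 = Seq(Event("beijing", "xicheng"), Event("beijing", "xicheng"), Event("shanghai", "pudong"))
    val batch2 = Seq(Event("beijing", "chaoyang"), Event("beijing", "xicheng"))
    val (store1, emitted1) = processBatch(Map.empty, batch1)
    val (store2, emitted2) = processBatch(store1, batch2)
    println(emitted1) // both provinces had data in batch 1
    println(emitted2) // only beijing: shanghai's state is kept but not re-emitted
    println(store2)
  }
}
```

In the real job, Spark's state store plays the role of `Store`, and the top-N step runs on the merged counts right after the merge.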
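Code
Since this is just a test, the code is fairly rough. The idea is to compute, in real time, the two cities with the highest counts in each province.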
import java.sql.Timestamp
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import scala.collection.mutable
object topNStreaming {
  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println(
        """
          | Usage: topNStreaming <brokers> <topics>
          |   brokers: a comma-separated list of Kafka brokers. Example: 192.168.0.0:9092,192.168.0.1:9092
          |   topics:  a comma-separated list of topics. Example: test,test1
        """.stripMargin)
      System.exit(1)
    }
    val Array(brokers, topics) = args.take(2)
    // construct the SparkSession
    val spark = SparkSession.builder()
      .appName("topNStreaming")
      .master("local[*]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    import spark.implicits._
    // read the Kafka stream; Kafka delivers key and value as binary, so cast them to strings first
    val logs = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", brokers)
      .option("subscribe", topics)
      .option("startingOffsets", "latest")
      .load()
      .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS value")
      .as[log]
      // each message value is "<timestamp> <province> <city> <userid> <adId>", separated by single spaces
      .map { l =>
        val v = l.value.split(" ")
        Events(getTimepstamp(v(0)), v(1), v(2), v(3).toLong, v(4).toLong)
      }
      // .withWatermark("timeStamp", "1 minutes")
      .groupByKey(event => event.province)
      .flatMapGroupsWithState(
        outputMode = OutputMode.Update(),
        timeoutConf = GroupStateTimeout.NoTimeout())(topNCountPerProvince)
    // start the query, printing the updated rows of each batch to the console
    val qs = logs.writeStream
      .format("console")
      .outputMode("update")
      .start()
    // block until the query terminates
    qs.awaitTermination()
  }
  /**
    * Merges one province's new events into its state and returns its top-2 cities.
    *
    * @param province the grouping key
    * @param events   the new events of this province in the current micro-batch
    * @param state    the per-province state (city -> cumulative count)
    * @return the top-2 cities of the province by cumulative count
    */
  def topNCountPerProvince(province: String, events: Iterator[Events], state: GroupState[State]): Iterator[Update] = {
    val oldState = if (state.exists) state.get else State(province, mutable.Map[String, Int]())
    val cityMaps = oldState.cityCounts
    // add this batch's events onto the accumulated per-city counts
    events.foreach { e =>
      cityMaps(e.city) = cityMaps.getOrElse(e.city, 0) + 1
    }
    state.update(State(province, cityMaps))
    /**
      * Take the cities with the two highest count values; cities that tie on count
      * are all returned, so the result can contain more than two rows, e.g.:
      * +----------+-----------+-------+
      * | province | city      | count |
      * +----------+-----------+-------+
      * | beijing  | xicheng   |   400 |
      * | beijing  | dongcheng |   332 |
      * | beijing  | chaoyang  |   332 |
      * +----------+-----------+-------+
      */
    val output = cityMaps.groupBy(_._2) // count -> cities with that count
      .toList
      .sortWith(_._1 > _._1)            // highest count first
      .take(2)                          // keep the two highest distinct counts
      .flatMap(f => f._2.toSeq)
      .map(v => Update(province, v._1, v._2))
    output.toIterator
  }
  // parse an epoch-milliseconds string into a java.sql.Timestamp
  def getTimepstamp(tm: String): Timestamp = {
    new Timestamp(tm.toLong)
  }
}
case class State(province: String, cityCounts: mutable.Map[String, Int])
case class log(key: String, value: String)
case class Events(timeStamp: Timestamp, province: String, city: String, userid: Long, adId: Long)
case class Update(province: String, city: String, count: Int)
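The tie-handling rule in topNCountPerProvince ("keep the two highest distinct counts, plus every city that has one of them") can be checked in isolation without a Kafka setup. Below is a small plain-Scala sketch. `TieAwareTopNCheck` and `topNWithTies` are hypothetical names, and the counts are made-up sample data modeled on the table in the code comment. The extra `sortBy` on city names is my addition so that the output is deterministic. The original code does not order tied cities.

```scala
import scala.collection.mutable

object TieAwareTopNCheck {
  // Same ranking rule as topNCountPerProvince: group cities by count, keep the
  // n highest distinct counts, and return every city that has one of them.
  def topNWithTies(cityCounts: collection.Map[String, Int], n: Int): List[(String, Int)] =
    cityCounts.groupBy(_._2)
      .toList
      .sortWith(_._1 > _._1)
      .take(n)
      .flatMap(_._2.toList.sortBy(_._1)) // order tied cities by name for deterministic output

  def main(args: Array[String]): Unit = {
    // Hypothetical sample counts for one province
    val counts = mutable.Map("xicheng" -> 400, "dongcheng" -> 332, "chaoyang" -> 332, "haidian" -> 120)
    println(topNWithTies(counts, 2)) // List((xicheng,400), (chaoyang,332), (dongcheng,332))
  }
}
```

Note that with n = 2 this returns three rows, because two cities share the second-highest count. If exactly two rows per province are required, you would have to add a tie-breaker and call `take(2)` on the flattened list instead.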