在上一篇文章中,首先按照netType进行了统计,接下来添加一个条件,按照城市进行统计:
def main(args: Array[String]): Unit = {
  // Local Spark session; partition-column type inference is disabled so the
  // "day" partition value is kept as a String instead of being parsed as a number.
  val spark = SparkSession.builder().appName("TopNStatJob")
    .config("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
    .master("local[2]").getOrCreate()

  // Load the cleaned access logs written as parquet by the previous step.
  val accessDF = spark.read.format("parquet").load("file:///E:/test/clean")
  // accessDF.printSchema()
  accessDF.show(false)

  // Most popular TopN netType
  // netTypeAccessTopNStat(spark, accessDF)
  // TopN courses per city
  cityTypeAccessTopNStat(spark, accessDF)

  // Side-effecting 0-arity method: keep the parentheses by convention.
  spark.stop()
}
/**
 * Top-N most accessed uids (courses) per city for one day and access type.
 *
 * Backward compatible with the original 2-argument call: the defaults
 * reproduce the previous hard-coded behavior (day=20190702, wifi, top 3).
 *
 * @param spark    active SparkSession
 * @param accessDF cleaned access-log DataFrame (expects columns
 *                 day, uid, city, netType)
 * @param day      partition day to analyse ("yyyyMMdd" string)
 * @param netType  access type to filter on
 * @param topN     how many ranked rows to keep per city
 */
def cityTypeAccessTopNStat(spark: SparkSession, accessDF: DataFrame,
                           day: String = "20190702", netType: String = "wifi",
                           topN: Int = 3): Unit = {
  // Hit count per (day, uid, city) for the requested slice.
  val cityAccessTopNDF = accessDF
    .filter(accessDF.col("day") === day && accessDF.col("netType") === netType)
    .groupBy("day", "uid", "city")
    .agg(count("uid").as("times"))
    .orderBy(desc("times"))
  cityAccessTopNDF.show(false)

  // Window function in Spark SQL: rank uids inside each city by hit count,
  // then keep only the top N per city.
  cityAccessTopNDF.select(cityAccessTopNDF("day")
    , cityAccessTopNDF("uid")
    , cityAccessTopNDF("city")
    , cityAccessTopNDF("times")
    , row_number()
      .over(Window.partitionBy("city")
        .orderBy(cityAccessTopNDF("times").desc))
      .as("times_rank")
  ).filter(s"times_rank <= $topN")
    .show(false)
}
运行结果如下:
将结果写入mysql
创建数据表:
-- Per-day, per-city Top-N access statistic.
-- NOTE: the primary key must include `city` — the result set is grouped by
-- (day, uid, city), so the same uid can legitimately appear in several
-- cities on the same day; (day, uid) alone would reject those rows.
create table day_netType_city_access_topn_stat (
  day        varchar(8)  not null,
  uid        bigint(10)  not null,
  city       varchar(20) not null,
  times      bigint(10)  not null,
  times_rank bigint(10)  not null,
  primary key (day, uid, city)
)
创建一个Entity
package cn.ac.iie.log

/**
 * One row of the per-city Top-N access statistic.
 *
 * @param day        statistic day ("yyyyMMdd" string)
 * @param uid        accessed id — presumably the course id, confirm with caller
 * @param city       city name
 * @param times      access count for this uid in this city
 * @param times_rank rank of this uid within its city (1 = most accessed)
 */
case class DayCityNetTypeAccessStat(
    day: String,
    uid: Long,
    city: String,
    times: Long,
    times_rank: Long)
创建Dao
/**
 * Batch-insert DayCityNetTypeAccessStat rows into
 * day_netType_city_access_topn_stat.
 *
 * Runs as one manually-committed transaction: either every row in `list`
 * is inserted, or — on any failure — the whole batch is rolled back, so a
 * pooled connection is never released with a half-applied batch pending.
 *
 * @param list rows to persist
 */
def insertDayNetTypeCityAccessTopN(list: ListBuffer[DayCityNetTypeAccessStat]): Unit = {
  var connection: Connection = null
  var pstmt: PreparedStatement = null
  try {
    connection = MysqlUtils.getConnection()
    // Manual commit so the whole batch is atomic.
    connection.setAutoCommit(false)
    val sql = "insert into day_netType_city_access_topn_stat (day, uid, city, times, times_rank) values (?,?,?,?,?)"
    pstmt = connection.prepareStatement(sql)
    for (ele <- list) {
      pstmt.setString(1, ele.day)
      pstmt.setLong(2, ele.uid)
      pstmt.setString(3, ele.city)
      pstmt.setLong(4, ele.times)
      pstmt.setLong(5, ele.times_rank)
      pstmt.addBatch()
    }
    pstmt.executeBatch() // execute the whole batch in one round trip
    connection.commit()  // manual commit
  } catch {
    case e: Exception =>
      e.printStackTrace()
      // Undo any partially executed batch; ignore secondary rollback failures.
      if (connection != null) {
        try connection.rollback() catch { case _: Exception => }
      }
  } finally {
    MysqlUtils.release(connection, pstmt)
  }
}
将结果写入到Mysql中
// Write the statistic result into MySQL.
// NOTE(review): `top3DF` is produced elsewhere (the windowed Top-3 result);
// each partition is buffered locally and batch-inserted through the Dao so
// one connection/transaction is used per partition.
try {
top3DF.foreachPartition(partitionOfRecords => {
val list = new ListBuffer[DayCityNetTypeAccessStat]
partitionOfRecords.foreach(info => {
val day = info.getAs[String]("day")
// assumes the "uid" column is stored as a String in the DataFrame — TODO confirm schema
val uid = info.getAs[String]("uid").toLong
val city = info.getAs[String]("city")
val times = info.getAs[Long]("times")
// row_number() yields an Int; it widens to the Long field of the case class
val timesRank = info.getAs[Int]("times_rank")
list.append(DayCityNetTypeAccessStat(day, uid, city, times, timesRank))
})
StatDao.insertDayNetTypeCityAccessTopN(list)
})
} catch {
// best-effort write: failures are logged, not rethrown
case e: Exception => e.printStackTrace()
}