package com.ws.sparksql
import com.ws.spark.IpFromUtils
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
/**
 * Spark SQL job: count occurrences of each IP's geolocation (province) in an access log.
 */
object SqlIpFromCount {

  /**
   * Entry point. Joins an access log against an IP-range rule table and prints,
   * per province, how many log lines originated there.
   *
   * Usage: SqlIpFromCount [rulesPath] [accessLogPath]
   * Paths may be supplied as program arguments; the original hard-coded
   * defaults are kept for backward compatibility when no args are given.
   */
  def main(args: Array[String]): Unit = {
    // Generalized: input paths come from args when provided.
    val rulesPath = if (args.length > 0) args(0) else "E:\\bigData\\testdata\\ip.txt"
    val accessLogPath = if (args.length > 1) args(1) else "E:\\bigData\\testdata\\access.log"

    val sparkSession = SparkSession.builder()
      .appName("SqlIpFromCount")
      .master("local[*]")
      .getOrCreate()
    import sparkSession.implicits._
    try {
      // Load the IP rule file: fields 2 and 3 are the numeric range bounds,
      // field 6 is the province name.
      val rulesData: Dataset[String] = sparkSession.read.textFile(rulesPath)
      val rules: Dataset[(Long, Long, String)] = rulesData.map { line =>
        val fields = line.split("[|]")
        (fields(2).toLong, fields(3).toLong, fields(6))
      }
      // Collect the (small) rule table to the driver, then broadcast it so each
      // executor receives one read-only copy instead of per-task serialization.
      // Broadcasting must go through the SparkContext.
      val rulesCollect: Array[(Long, Long, String)] = rules.collect()
      val broadCast: Broadcast[Array[(Long, Long, String)]] =
        sparkSession.sparkContext.broadcast(rulesCollect)

      // Parse the access log: field 1 holds the client IP, converted to its
      // decimal (Long) form so it can be range-searched against the rules.
      val data: Dataset[String] = sparkSession.read.textFile(accessLogPath)
      val ipNum: Dataset[Long] = data.map { line =>
        val fields = line.split("[|]")
        IpFromUtils.ipToLong(fields(1))
      }
      val ipNumDataFrame: DataFrame = ipNum.toDF("ip_num")
      // createOrReplaceTempView avoids an AnalysisException if a view named
      // t_ips already exists in this session (createTempView would throw).
      ipNumDataFrame.createOrReplaceTempView("t_ips")

      // Register a UDF mapping an IP number to its province via binary search
      // over the broadcast rule table; unmatched IPs fall back to "暂无".
      sparkSession.udf.register("iptoProvince", (ipNum: Long) => {
        // Read the broadcast value on the executor side.
        val rulesBroad: Array[(Long, Long, String)] = broadCast.value
        val index = IpFromUtils.binarySearch(rulesBroad, ipNum)
        if (index != -1) rulesBroad(index)._3 else "暂无"
      })

      // Broadcast-join equivalent expressed as a UDF lookup + aggregation.
      val result = sparkSession.sql("select iptoProvince(ip_num) province , count(*) as times from t_ips group by province order by times desc")
      result.show()
    } finally {
      // Always release the local Spark context, even if a stage fails.
      sparkSession.stop()
    }
  }
}
/*
 * Sample result:
 * +--------+-----+
 * |province|times|
 * +--------+-----+
 * |    陕西| 1824|
 * |    北京| 1535|
 * |    重庆|  868|
 * |    河北|  383|
 * |    云南|  126|
 * +--------+-----+
 */