1 Spark SQL概述
1.1 定义
Spark SQL是用于结构化数据处理的Spark模块,与基本的Spark RDD API不同,Spark SQL提供的接口为Spark童工了有关数据结构和正在执行的计算的更多信息。
在内部,Spark SQL使用这些额外的信息来执行额外的优化。与Spark SQL交互的方式有多种,包括SQL和Dataset API。计算结果时,使用相同的执行引擎,与您用于表达计算的API/语言无关。
1.2 为什么要有Spark SQL?
为了实现类Hive的功能
Hadoop --> MapReduce--> Hive(HQL)
Spark --> RDD --> Shark(Spark+HIve)--> Spark SQL
--> Hive On Spark
Spark on Hive:使用Spark操作Hive中的表格,Hive只作为存储元数据,Spark负责SQL解析优化,语法是Spark SQL语法,Spark底层采用优化后的 df 或 fs 执行。
Hive on Spark:使用Hive操作,只是将原有的MR引擎替换为Spark。Hive既作为存储元数据又负责SQL的解析优化,语法是HQL语法,执行引擎变为Spark,Spark负责采用RDD执行。
1.3 Spark的发展历史
RDD(Spark 1.0) --> DataFrame(Spark 1.3) --> DataSet(Spark 1.6)
如果同样的数据都给到这三个数据结构,在她们分别计算之后,都会给出相同的结果。不同的是他们的执行效率和执行方式。在现在的版本中,SparkSet的性能最好,已经称为了唯一使用的接口。其中DataFrame已经在底层被看作是特殊泛型的DataSet<Row>。
三者的共性:
- RDD、DataFrame、DataSet全都是Spark平台下的分布式弹性数据集,为处理超大型数据提供便利。
- 三者都有惰性机制,在进行创建、转换时(如map方法),不会立即执行,只有在遇到Action行动算子如foreach时,三者才会开始遍历运算。
- 三者有许多共同的函数,如filter,sort等。
- 三者都会根据Spark的内存情况自动缓存运算。
- 三者都有分区的概念。
1.4 Spark SQL的特点
1 易整合:无缝的整合了SQL查询和Spark编程
2 统一的数据访问形式:使用不同的方式连接不同的数据源
3 兼容Hive:在已有的仓库上直接运行SQL或HQL
4 标准的数据连接:通过JDBC或ODBC来连接
2 Spark SQL编程
2.1 SparkSession
在老版本中,SparkSQL提供隆重SQL查询起始点:
SQLContext:用于Spark自己提供的SQL查询
HiveContext:用于连接Hive的查询
SparkSession是Spark最新的SQL查询起始点,实质上是SQLContext和HiveContext的组合,所以在SQLContext和HiveContext上可用的API在SparkSession上同样是可以使用的。
SparkSession内部封装了SparkContext,所以计算实际上是由SparkContext完成的。当使用Spark-Shell时,Spark框架会自动的创建一个名称叫做Spark的SparkSession,就像我们以前可以自动获取到一个sc来表示SparkContext。
2.2 常用方式
2.2.1 方法调用
1 创建Maven工程
2 在项目SparkSQLTest上点击右键,Add Framework Support --> Scala
3 在main下创建Scala文件夹,并make Directory as Source Root --> 在Scala下创建包
在pom.xml文件添加spark-sql依赖和scala编译插件
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>3.1.3</version>
</dependency>
</dependencies>
<build>
<finalName>SparkSQLTest</finalName>
<plugins>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.4.6</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
调整日志输出等级 log4j.properties
log4j.rootCategory=ERROR, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
# Set the default spark-shell log level to ERROR. When running the spark-shell, the
# log level for this class is used to overwrite the root logger's log level, so that
# the user can have different defaults for the shell and regular Spark apps.
log4j.logger.org.apache.spark.repl.Main=ERROR
# Settings to quiet third party logs that are too verbose
log4j.logger.org.spark_project.jetty=ERROR
log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=ERROR
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=ERROR
log4j.logger.org.apache.parquet=ERROR
log4j.logger.parquet=ERROR
# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR
在SparkSQL中DS直接支持的算子有:map(底层优化为mapPartition)、mapPartition、flatMap、groupByKey(聚合算子全部由groupByKey开始)、filter、distinct、coalesce、repartition、sort、和sortBy(不是函数式算子,但是不影响使用)。
区分算子和后面学习的特殊语法:需要填函数的是算子,直接填字段的是特殊语法。
2.2.2 SQL使用方式
package com.atguigu.sparksql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
object Test02_SQL {
def main(args: Array[String]): Unit = {
// 1. 创建sparkSession的配置对象
val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[*]")
// 2. 获取sparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
import spark.implicits._
// 3. 编写代码
// 对于sql处理来说,数据不需要有具体的对象,使用Row类型的DF即可
val df: DataFrame = spark.read.json("input/user.json")
// 创建视图 => 转换为表格 填写表名
// 临时视图的生命周期和当前的sparkSession绑定
// orReplace表示覆盖之前相同名称的视图
df.createOrReplaceTempView("user")
// 创建全局视图
// 全局视图的生命周期和sparkApplication绑定 在本案例中没有区别
// 全局视图调用表格时需要添加前缀global_temp.
df.createOrReplaceGlobalTempView("user1")
// 编写sql => 支持所有的hiveSQL语法,并且会使用spark优化器
val df1: DataFrame = spark.sql(
"""
|select name,
| age+10 newAge
|from
| user
|""".stripMargin)
df1.show()
// 4.关闭sparkSession
spark.stop()
}
}
2.2.3 DSL特殊语法
package com.atguigu.sparksql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Row, SparkSession}
object Test03_DSL {
def main(args: Array[String]): Unit = {
// 1. 创建sparkSession的配置对象
val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[*]")
// 2. 获取sparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入隐藏依赖
import spark.implicits._
// 3. 编写代码
// 创建ds
val ds: Dataset[User] = spark.read.json("input/user.json").as[User]
// 特殊语法有两种写法 使用没有区别
val ds1: Dataset[Row] = ds.select($"name", $"age" + 20 as "newAge")
.where($"age" < 19)
.sort("age")
val ds2: Dataset[Row] = ds.select('name, 'age + 20 as 'newAge)
.where('age < 19)
.sort('newAge)
ds2.show()
// 分组聚合
ds.groupBy($"name")
.avg("age" )
.select($"name",$"avg(age)" as "avg_age")
.show()
// 4.关闭sparkSession
spark.stop()
}
}
2.3 SQL语法的用户定义函数
2.3.1 UDF
一行进入,一行输出。
package com.atguigu.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
object Test04_UDF {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
// 读取数据
val df: DataFrame = spark.read.json("input/user.json")
// 创建临时视图
df.createOrReplaceTempView("user")
// 注册UDF函数
spark.udf.register("addName", (x: String) => "Name: " + x)
// 调用UDF函数
spark.sql(
"""
|select
| addName(name),
| age
|from user
|""".stripMargin).show()
// 4 关闭SparkSession
spark.stop()
}
}
2.3.2 UDAF
输入多行,返回一行。通常和groupBy一起使用,如果直接使用UDAF函数,默认将所有的数据合并在一起。
Spark 3.x 推荐使用 extends Aggregaor 自定义UDAF,属于强类型的DataSet方式。
Spark 2.x 使用 extends UserDefinedAggregateFunction,属于弱类型的DataFrame。
package com.atguigu.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}
object Test05_UDAF {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
val df: DataFrame = spark.read.json("input/user.json")
df.createOrReplaceTempView("user")
spark.udf register("myAvg", functions.udaf(new MyAvgUDAF()))
spark.sql(
"""
|select
| myAvg(age)
|from user
|""".stripMargin).show()
// 4 关闭SparkSession
spark.stop()
}
}
case class Buff(var sum: Long, var count: Long)
class MyAvgUDAF extends Aggregator[Long, Buff, Double] {
// 初始化缓冲区
override def zero: Buff = Buff(0L, 0L)
// 将输入的年龄和缓冲区的数据进行聚合
override def reduce(buff: Buff, age: Long): Buff = {
buff.sum = buff.sum + age
buff.count = buff.count + 1
buff
}
// 多个缓冲区数据合并
override def merge(buff1: Buff, buff2: Buff): Buff = {
buff1.sum = buff1.sum + buff2.sum
buff1.count = buff1.count + buff2.count
buff1
}
// 完成聚合操作,获取最终结果
override def finish(reduction: Buff): Double = {
reduction.sum.toDouble / reduction.count
}
// SparkSQL对传递的对象的序列化操作(编码)
// 自定义类型就是product,自带类型根据类型选择
override def bufferEncoder: Encoder[Buff] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
2.3.3 UDTF(没有)
输入一行,返回多行(Hive)。SparkSQL中没有UDFT,需要使用算子类型的flatMap先完成拆分。
3 SparkSQL数据的加载和保存
3.1 读取和保存文件
SparkSQL读取和保存的文件一般为三种,JSON文件、CSV文件和类似存储文件,同时可以通过添加参数,来识别不同的存储和压缩格式。
3.1.1 CSV文件
package com.atguigu.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, DataFrameReader, Dataset, SaveMode, SparkSession}
object Test06_csv {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
val reader: DataFrameReader = spark.read
val df: DataFrame = reader
// 默认为 , 分割
.option("sep", ";")
// 默认为false,没有读取第一行为列名的功能
.option("header", false)
// 不需要填写压缩格式,自适应
.csv("input/1.txt")
// 添加列名才能转换为ds
val df1: DataFrame = df.toDF("age", "name")
val ds: Dataset[User] = df1.as[User]
ds.show()
ds.write
// 默认false是否写出头信息
.option("header",true)
// 默认为写出数据的分隔符
.option("sep",";")
.option("compression","gzip")
// 四种写出模式:append 追加 Ignore 忽略本次写出 Overwrite 覆盖写 ErrorIfExists 如果存在报错
.mode(SaveMode.Append)
.csv("output")
// 4 关闭SparkSession
spark.stop()
}
}
// csv文件读取的数据需要用String接受
case class User(name: String, age: String)
3.1.2 JSON文件
package com.atguigu.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, DataFrameReader, Dataset, SparkSession}
object Test07_Json {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
val df: DataFrame = spark.read.json("input/user.json")
println(df.schema)
df.show()
val ds: Dataset[User] = df.as[User]
ds.write
.json("output")
// 4 关闭SparkSession
spark.stop()
}
}
3.1.3 Parquet文件
列式存储的数据自带列分割。
package com.atguigu.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
object Test08_Parquet {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
val df: DataFrame = spark.read.json("input/user.json")
val ds: Dataset[User] = df.as[User]
// 写出默认使用snappy
ds.write.parquet("output")
// 字解析,性能优秀
println(spark.read.parquet("output").schema)
spark.read.parquet("output").show()
// 4 关闭SparkSession
spark.stop()
}
}
3.2 与MySQL交互
导入依赖
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.27</version>
</dependency>
package com.atguigu.spark.sql
import java.util.Properties
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
object Test09_Table {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
val df: DataFrame = spark.read.json("input/user.json")
val properties = new Properties()
properties.setProperty("user", "root")
properties.setProperty("password", "123456")
df.write
// 如果是覆盖写模式
.mode(SaveMode.Append)
.jdbc("jdbc:mysql://hadoop102:3306", "gmall.testInfo", properties)
// 读取mysql
val df1: DataFrame = spark.read.jdbc("jdbc:mysql://hadoop102:3306", "gmall.user_info", properties)
df1.show()
println(df1.schema)
// 4 关闭SparkSession
spark.stop()
}
}
3.3 与Hive交互
SparkSQL可以采用内嵌Hive(Spark开箱机用的Hive),也可以采用外部Hive。企业开发中,通常采用外部Hive。
添加依赖
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>3.1.3</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.27</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.12</artifactId>
<version>3.1.3</version>
</dependency>
</dependencies>
package com.atguigu.sparksql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
object Test10_Hive {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME","atguigu")
// 1. 创建sparkSession的配置对象
val conf: SparkConf = new SparkConf()
.setAppName("sparkSql").setMaster("local[*]")
// 2. 获取sparkSession
val spark: SparkSession = SparkSession
.builder().config(conf).enableHiveSupport().getOrCreate()
import spark.implicits._
// 3. 编写代码
val df: DataFrame = spark
.sql("show tables")
df.show()
spark.sql("insert into table user_info values('zhangsan',10)")
spark.sql("""select * from user_info""").show()
// 4.关闭sparkSession
spark.stop()
}
}
4 SparkSQL项目实战
需求:各区域热门商品Top3
使用Spark-SQL来完成复杂的需求,可以使用UDF或UDAF
- 查询出来所有的点击记录,并与city_info表连接,得到每个城市所在的地区,与 Product_info表连接得到商品名称。
- 按照地区和商品名称分组,统计出每个商品在每个地区的总点击次数。
- 每个地区内按照点击次数降序排列。
- 只取前三名,并把结果保存在数据库中。
- 城市备注需要自定义UDAF函数。
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object Test_Top03 {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()
// 导入特殊依赖,隐式转换
import spark.implicits._
spark.sql("use default")
// 3 代码
// 注册自定义聚合函数
spark.udf.register("city_remark", functions.udaf(new CityRemarkUDAF()))
// 1 查询点击记录并和城市表,产品表内连接
spark.sql(
"""
|select
| c.area,--地区
| c.city_name,--城市
| p.product_name,--商品名称
| v.click_product_id--点击商品id
|from user_visit_action v
|join city_info c
|on v.city_id=c.city_id
|join product_info p
|on v.click_product_id=p.product_id
|where click_product_id>-1
|""".stripMargin).createOrReplaceTempView("t1")
// 2 分组计算每个区域,每个产品的点击量
spark.sql(
"""
|select
| t1.area,
| t1.product_name,
| count(*) click_count,
| city_remark(t1.city_name) city_remark
|from t1
|group by t1.area,t1.product_name
|""".stripMargin).show()
// 4 关闭SparkSession
spark.stop()
}
}
// 中间缓存数据
case class Buffer(var totalcnt: Long, var cityMap: mutable.Map[String, Long])
class CityRemarkUDAF extends Aggregator[String, Buffer, String] {
override def zero: Buffer = Buffer(0L, mutable.Map[String, Long]())
override def reduce(buffer: Buffer, city: String): Buffer = {
// 总点击次数
buffer.totalcnt += 1
// 每个城市的点击次数
var newCount: Long = buffer.cityMap.getOrElse(city, 0L) + 1
buffer.cityMap.update(city, newCount)
buffer
}
override def merge(b1: Buffer, b2: Buffer): Buffer = {
// 合并所有城市的点击数量的综合
b1.totalcnt += b2.totalcnt
// 合并城市Map(2个Map合并)
b2.cityMap.foreach {
case (city, count) => {
val newCnt: Long = b1.cityMap.getOrElse(city, 0L) + count
b1.cityMap.update(city, newCnt)
}
}
b1
}
override def finish(reduction: Buffer): String = {
val remarkList: ListBuffer[String] = ListBuffer[String]()
//将统计的城市点击数量的集合进行排序,并取出前两名
val cityCountList: List[(String, Long)] = reduction.cityMap.toList.sortWith {
(left, right) => {
left._2 > right._2
}
}.take(2)
var sum: Long = 0L
// 计算出前两名的百分比
cityCountList.foreach {
case (city, cnt) => {
val r = cnt * 100 / reduction.totalcnt
remarkList.append(city + " " + r + "%")
sum += r
}
}
// 如果城市个数大于2,用其他表示
if (reduction.cityMap.size > 2) {
remarkList.append("其他" + (100 - sum) + "%")
}
remarkList.mkString(",")
}
override def bufferEncoder: Encoder[Buffer] = Encoders.product
override def outputEncoder: Encoder[String] = Encoders.STRING
}
// 简化版
package com.atguigu.sparksql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object SparkSQL13_TopN_Layne {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "atguigu")
//TODO 1 创建SparkConf配置文件,并设置App名称
val conf = new SparkConf().setAppName("SparkSQLTest").setMaster("local[*]")
//TODO 2 利用SparkConf创建sparksession对象
val spark: SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()
//注册自定义UDAF函数
spark.udf.register("city_remark", functions.udaf(new CityRemarkUDAF()))
//执行sparksql
spark.sql(
"""
|select
| t3.area,
| t3.product_name,
| t3.click_count,
| t3.city_remark
|from
|(
| select
| t2.area,
| t2.product_name,
| t2.click_count,
| t2.city_remark,
| rank() over(partition by t2.area order by t2.click_count desc) rk
| from
| (
| select
| t1.area,
| t1.product_name,
| count(*) click_count,
| city_remark(t1.city_name) city_remark
| from
| (
| select
| c.area,
| c.city_name,
| p.product_name,
| v.click_product_id
| from user_visit_action v
| join city_info c
| on v.city_id = c.city_id
| join product_info p
| on v.click_product_id = p.product_id
| where v.click_product_id > -1
| )t1
| group by t1.area,t1.product_name
| ) t2
|) t3
|where t3.rk <= 3
|""".stripMargin).show(1000, false)
// 3 关闭资源
spark.stop()
}
}
/**
* 输入: city_name String
* 缓冲区: 这个大区的总点击量 Map[(City_name,城市的点击量)] Buffer
* 输出: city_remark String
*/
case class Buffer(var totalcnt: Long, var cityMap: mutable.Map[String, Long])
class CityRemarkUDAF extends Aggregator[String, Buffer, String] {
override def zero: Buffer = Buffer(0L, mutable.Map[String, Long]())
//单个分区聚合方法 buffer 和 city
override def reduce(buffer: Buffer, city: String): Buffer = {
buffer.totalcnt += 1
buffer.cityMap(city) = buffer.cityMap.getOrElse(city, 0L) + 1
buffer
}
//多个buffer之间的聚合
override def merge(b1: Buffer, b2: Buffer): Buffer = {
b1.totalcnt += b2.totalcnt
b2.cityMap.foreach {
case (city, cityCnt) => {
b1.cityMap(city) = b1.cityMap.getOrElse(city, 0L) + cityCnt
}
}
b1
}
//最终逻辑计算方法
override def finish(buffer: Buffer): String = {
//0 定义一个listBuffer 用来存储最后返回结果
val remarkList = ListBuffer[String]()
//1 取出buffer的map,然后转成list,按照城市点击量 倒序排序
val cityList: List[(String, Long)] = buffer.cityMap.toList.sortWith(
(t1, t2) => t1._2 > t2._2
)
//2 取出排好序的cityList前两名,特殊处理
var sum = 0L
cityList.take(2).foreach {
case (city, cityCnt) => {
val res: Long = cityCnt * 100 / buffer.totalcnt
remarkList.append(city + " " + res + "%")
sum += res //将前两个加起来,方便计算第三个其他百分比
}
}
//3 计算第三个其他的百分比
if (buffer.cityMap.size > 2) {
remarkList.append("其他 " + (100 - sum) + "%")
}
//4 返回remarkList字符串
remarkList.mkString(",")
}
override def bufferEncoder: Encoder[Buffer] = Encoders.product
override def outputEncoder: Encoder[String] = Encoders.STRING
}
// 算子实现自定义函数
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, KeyValueGroupedDataset, Row, SparkSession}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object Test02_Top3 {
def main(args: Array[String]): Unit = {
// 1 创建Spark Session的配置对象
// local | yarn
val conf: SparkConf = new SparkConf()
.setAppName("sparkSQL")
.setMaster("local[*]")
// 2 获取SparkSession
val spark: SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()
spark.sql("use default")
// 导入特殊依赖,隐式转换
import spark.implicits._
// 3 代码
// 读取数据使用SQL最方便
val df: DataFrame = spark.sql(
"""
| select
| c.area, --地区
| c.city_name, -- 城市
| p.product_name, -- 商品名称
| v.click_product_id -- 点击商品id
| from user_visit_action v
| join city_info c
| on v.city_id = c.city_id
| join product_info p
| on v.click_product_id = p.product_id
| where click_product_id > -1
|""".stripMargin)
val value: KeyValueGroupedDataset[(String, String), Row] = df.groupByKey(row => {
(row.getString(0), row.getString(2))
})
// tuple:(area, product_name)
val groupMarkDS: Dataset[AreaProductInfo] = value.mapGroups((tuple: (String, String), it: Iterator[Row]) => {
// 声明总数 和hashMap统计城市百分比
var totalCount = 0L
val hashMap = new mutable.HashMap[String, Long]()
// 循环遍历一组的每一行数据
it.foreach(row => {
totalCount += 1
val cityName: String = row.getString(1)
hashMap.put(cityName, 1L + hashMap.getOrElse(cityName, 0L))
})
// 城市标记
val remarkList = ListBuffer[String]()
// 将统计的城市点击数量的集合进行排序,并取出前两名
val cityCountList: List[(String, Long)] = hashMap.toList.sortWith(
(left, right) => {
left._2 > right._2
}
).take(2)
var sum: Long = 0L
// 计算出前两名的百分比
cityCountList.foreach {
case (city, cnt) => {
val r = cnt * 100 / totalCount
remarkList.append(city + " " + r + "%")
sum += r
}
}
// 如果城市个数大于2,用其他表示
if (hashMap.size > 2) {
remarkList.append("其他 " + (100 - sum) + "%")
}
val remarkStr: String = remarkList.mkString(",")
// 创建返回值
AreaProductInfo(tuple._1,tuple._2,totalCount,remarkStr)
})(Encoders.product[AreaProductInfo])
groupMarkDS.toDF().createOrReplaceTempView("t2")
spark.sql(
"""
|select
| *,
| rank() over(partition by t2.area order by t2.click_count desc) rank
|from t2
|""".stripMargin).createOrReplaceTempView("t3")
spark.sql(
"""
|select
| *
|from t3
|where rank<=3
|""".stripMargin).show()
// 4 关闭SparkSession
spark.stop()
}
}
case class AreaProductInfo(area: String, product_name: String, click_count: Long, cityMark: String)