Spark SQL

1 Spark SQL概述

1.1 定义

Spark SQL是用于结构化数据处理的Spark模块,与基本的Spark RDD API不同,Spark SQL提供的接口为Spark童工了有关数据结构和正在执行的计算的更多信息。

在内部,Spark SQL使用这些额外的信息来执行额外的优化。与Spark SQL交互的方式有多种,包括SQL和Dataset API。计算结果时,使用相同的执行引擎,与您用于表达计算的API/语言无关。

1.2 为什么要有Spark SQL?

为了实现类Hive的功能
Hadoop	--> MapReduce--> Hive(HQL)
Spark	--> RDD		--> Shark(Spark+HIve)--> Spark SQL
										 --> Hive On Spark

Spark on Hive:使用Spark操作Hive中的表格,Hive只作为存储元数据,Spark负责SQL解析优化,语法是Spark SQL语法,Spark底层采用优化后的 df 或 fs 执行。

Hive on Spark:使用Hive操作,只是将原有的MR引擎替换为Spark。Hive既作为存储元数据又负责SQL的解析优化,语法是HQL语法,执行引擎变为Spark,Spark负责采用RDD执行。

在这里插入图片描述

1.3 Spark的发展历史

RDD(Spark 1.0) --> DataFrame(Spark 1.3) --> DataSet(Spark 1.6)

如果同样的数据都给到这三个数据结构,在她们分别计算之后,都会给出相同的结果。不同的是他们的执行效率和执行方式。在现在的版本中,SparkSet的性能最好,已经称为了唯一使用的接口。其中DataFrame已经在底层被看作是特殊泛型的DataSet<Row>。

三者的共性:

  1. RDD、DataFrame、DataSet全都是Spark平台下的分布式弹性数据集,为处理超大型数据提供便利。
  2. 三者都有惰性机制,在进行创建、转换时(如map方法),不会立即执行,只有在遇到Action行动算子如foreach时,三者才会开始遍历运算。
  3. 三者有许多共同的函数,如filter,sort等。
  4. 三者都会根据Spark的内存情况自动缓存运算。
  5. 三者都有分区的概念。

1.4 Spark SQL的特点

1 易整合:无缝的整合了SQL查询和Spark编程

2 统一的数据访问形式:使用不同的方式连接不同的数据源

3 兼容Hive:在已有的仓库上直接运行SQL或HQL

4 标准的数据连接:通过JDBC或ODBC来连接

2 Spark SQL编程

2.1 SparkSession

在老版本中,SparkSQL提供隆重SQL查询起始点:

SQLContext:用于Spark自己提供的SQL查询

HiveContext:用于连接Hive的查询

SparkSession是Spark最新的SQL查询起始点,实质上是SQLContext和HiveContext的组合,所以在SQLContext和HiveContext上可用的API在SparkSession上同样是可以使用的。

SparkSession内部封装了SparkContext,所以计算实际上是由SparkContext完成的。当使用Spark-Shell时,Spark框架会自动的创建一个名称叫做Spark的SparkSession,就像我们以前可以自动获取到一个sc来表示SparkContext。

2.2 常用方式

2.2.1 方法调用

1 创建Maven工程
2 在项目SparkSQLTest上点击右键,Add Framework Support --> Scala
3 在main下创建Scala文件夹,并make Directory as Source Root --> 在Scala下创建包
在pom.xml文件添加spark-sql依赖和scala编译插件

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.12</artifactId>
        <version>3.1.3</version>
    </dependency>
</dependencies>

<build>
<finalName>SparkSQLTest</finalName>
<plugins>
        <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <version>3.4.6</version>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>testCompile</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>
调整日志输出等级 log4j.properties

log4j.rootCategory=ERROR, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n

# Set the default spark-shell log level to ERROR. When running the spark-shell, the
# log level for this class is used to overwrite the root logger's log level, so that
# the user can have different defaults for the shell and regular Spark apps.
log4j.logger.org.apache.spark.repl.Main=ERROR

# Settings to quiet third party logs that are too verbose
log4j.logger.org.spark_project.jetty=ERROR
log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=ERROR
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=ERROR
log4j.logger.org.apache.parquet=ERROR
log4j.logger.parquet=ERROR

# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support
log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL
log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR

在SparkSQL中DS直接支持的算子有:map(底层优化为mapPartition)、mapPartition、flatMap、groupByKey(聚合算子全部由groupByKey开始)、filter、distinct、coalesce、repartition、sort、和sortBy(不是函数式算子,但是不影响使用)。

区分算子和后面学习的特殊语法:需要填函数的是算子,直接填字段的是特殊语法。
在这里插入图片描述

2.2.2 SQL使用方式

package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Test02_SQL {
  def main(args: Array[String]): Unit = {
    // 1. 创建sparkSession的配置对象
    val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[*]")

    // 2. 获取sparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    import spark.implicits._
    // 3. 编写代码
    // 对于sql处理来说,数据不需要有具体的对象,使用Row类型的DF即可
    val df: DataFrame = spark.read.json("input/user.json")

    // 创建视图 => 转换为表格 填写表名
    // 临时视图的生命周期和当前的sparkSession绑定
    // orReplace表示覆盖之前相同名称的视图
    df.createOrReplaceTempView("user")

    // 创建全局视图
    // 全局视图的生命周期和sparkApplication绑定  在本案例中没有区别
    // 全局视图调用表格时需要添加前缀global_temp.
    df.createOrReplaceGlobalTempView("user1")

    // 编写sql  => 支持所有的hiveSQL语法,并且会使用spark优化器
    val df1: DataFrame = spark.sql(
      """
        |select name,
        |   age+10 newAge
        |from
        |   user
        |""".stripMargin)

    df1.show()

    // 4.关闭sparkSession
    spark.stop()
  }
}

2.2.3 DSL特殊语法

package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Row, SparkSession}

object Test03_DSL {
  def main(args: Array[String]): Unit = {
    // 1. 创建sparkSession的配置对象
    val conf: SparkConf = new SparkConf().setAppName("sparkSql").setMaster("local[*]")

    // 2. 获取sparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入隐藏依赖
    import spark.implicits._

    // 3. 编写代码
    // 创建ds
    val ds: Dataset[User] = spark.read.json("input/user.json").as[User]

    // 特殊语法有两种写法  使用没有区别
    val ds1: Dataset[Row] = ds.select($"name", $"age" + 20 as "newAge")
      .where($"age" < 19)
      .sort("age")

    val ds2: Dataset[Row] = ds.select('name, 'age + 20 as 'newAge)
      .where('age < 19)
      .sort('newAge)

    ds2.show()

    // 分组聚合
    ds.groupBy($"name")
      .avg("age" )
      .select($"name",$"avg(age)" as "avg_age")
      .show()

    // 4.关闭sparkSession
    spark.stop()
  }
}

2.3 SQL语法的用户定义函数

2.3.1 UDF

一行进入,一行输出。

package com.atguigu.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Test04_UDF {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    // 读取数据
    val df: DataFrame = spark.read.json("input/user.json")
    // 创建临时视图
    df.createOrReplaceTempView("user")
    // 注册UDF函数
    spark.udf.register("addName", (x: String) => "Name: " + x)
    // 调用UDF函数
    spark.sql(
      """
        |select
        |   addName(name),
        |   age
        |from user
        |""".stripMargin).show()

    // 4 关闭SparkSession
    spark.stop()
  }
}

2.3.2 UDAF

输入多行,返回一行。通常和groupBy一起使用,如果直接使用UDAF函数,默认将所有的数据合并在一起。

Spark 3.x 推荐使用 extends Aggregaor 自定义UDAF,属于强类型的DataSet方式。

Spark 2.x 使用 extends UserDefinedAggregateFunction,属于弱类型的DataFrame。

package com.atguigu.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}

object Test05_UDAF {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    val df: DataFrame = spark.read.json("input/user.json")

    df.createOrReplaceTempView("user")

    spark.udf register("myAvg", functions.udaf(new MyAvgUDAF()))

    spark.sql(
      """
        |select
        |   myAvg(age)
        |from user
        |""".stripMargin).show()

    // 4 关闭SparkSession
    spark.stop()
  }
}

case class Buff(var sum: Long, var count: Long)

class MyAvgUDAF extends Aggregator[Long, Buff, Double] {

  // 初始化缓冲区
  override def zero: Buff = Buff(0L, 0L)

  // 将输入的年龄和缓冲区的数据进行聚合
  override def reduce(buff: Buff, age: Long): Buff = {
    buff.sum = buff.sum + age
    buff.count = buff.count + 1
    buff
  }

  // 多个缓冲区数据合并
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.sum = buff1.sum + buff2.sum
    buff1.count = buff1.count + buff2.count
    buff1
  }

  // 完成聚合操作,获取最终结果
  override def finish(reduction: Buff): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // SparkSQL对传递的对象的序列化操作(编码)
  // 自定义类型就是product,自带类型根据类型选择
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

2.3.3 UDTF(没有)

输入一行,返回多行(Hive)。SparkSQL中没有UDFT,需要使用算子类型的flatMap先完成拆分。

3 SparkSQL数据的加载和保存

3.1 读取和保存文件

SparkSQL读取和保存的文件一般为三种,JSON文件、CSV文件和类似存储文件,同时可以通过添加参数,来识别不同的存储和压缩格式。

3.1.1 CSV文件

package com.atguigu.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, DataFrameReader, Dataset, SaveMode, SparkSession}

object Test06_csv {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    val reader: DataFrameReader = spark.read

    val df: DataFrame = reader
      // 默认为 , 分割
      .option("sep", ";")
      // 默认为false,没有读取第一行为列名的功能
      .option("header", false)
      // 不需要填写压缩格式,自适应
      .csv("input/1.txt")

    // 添加列名才能转换为ds
    val df1: DataFrame = df.toDF("age", "name")
    val ds: Dataset[User] = df1.as[User]
    ds.show()

    ds.write
    // 默认false是否写出头信息
        .option("header",true)
    // 默认为写出数据的分隔符
        .option("sep",";")
        .option("compression","gzip")
    // 四种写出模式:append 追加 Ignore 忽略本次写出 Overwrite 覆盖写 ErrorIfExists 如果存在报错
        .mode(SaveMode.Append)
        .csv("output")

    // 4 关闭SparkSession
    spark.stop()
  }
}

// csv文件读取的数据需要用String接受
case class User(name: String, age: String)

3.1.2 JSON文件

package com.atguigu.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, DataFrameReader, Dataset, SparkSession}

object Test07_Json {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    val df: DataFrame = spark.read.json("input/user.json")

    println(df.schema)
    df.show()

    val ds: Dataset[User] = df.as[User]

    ds.write
        .json("output")

    // 4 关闭SparkSession
    spark.stop()
  }
}

3.1.3 Parquet文件

列式存储的数据自带列分割。

package com.atguigu.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object Test08_Parquet {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    val df: DataFrame = spark.read.json("input/user.json")

    val ds: Dataset[User] = df.as[User]
    // 写出默认使用snappy
    ds.write.parquet("output")
    // 字解析,性能优秀
    println(spark.read.parquet("output").schema)
    spark.read.parquet("output").show()

    // 4 关闭SparkSession
    spark.stop()
  }
}

3.2 与MySQL交互

导入依赖

<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>5.1.27</version>
</dependency>
package com.atguigu.spark.sql

import java.util.Properties

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object Test09_Table {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    val df: DataFrame = spark.read.json("input/user.json")

    val properties = new Properties()
    properties.setProperty("user", "root")
    properties.setProperty("password", "123456")

    df.write
      // 如果是覆盖写模式
      .mode(SaveMode.Append)
      .jdbc("jdbc:mysql://hadoop102:3306", "gmall.testInfo", properties)

    // 读取mysql
    val df1: DataFrame = spark.read.jdbc("jdbc:mysql://hadoop102:3306", "gmall.user_info", properties)
    df1.show()
    println(df1.schema)


    // 4 关闭SparkSession
    spark.stop()
  }
}

3.3 与Hive交互

SparkSQL可以采用内嵌Hive(Spark开箱机用的Hive),也可以采用外部Hive。企业开发中,通常采用外部Hive。

添加依赖

<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.12</artifactId>
        <version>3.1.3</version>
    </dependency>

    <dependency>
        <groupId>mysql</groupId>
        <artifactId>mysql-connector-java</artifactId>
        <version>5.1.27</version>
    </dependency>

    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-hive_2.12</artifactId>
        <version>3.1.3</version>
    </dependency>
</dependencies>
package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Test10_Hive {
  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME","atguigu")
    // 1. 创建sparkSession的配置对象
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSql").setMaster("local[*]")

    // 2. 获取sparkSession
    val spark: SparkSession = SparkSession
      .builder().config(conf).enableHiveSupport().getOrCreate()
    import spark.implicits._
    // 3. 编写代码
    val df: DataFrame = spark
      .sql("show tables")

    df.show()

    spark.sql("insert into table user_info values('zhangsan',10)")
    spark.sql("""select * from user_info""").show()

    // 4.关闭sparkSession
    spark.stop()
  }
}

4 SparkSQL项目实战

需求:各区域热门商品Top3

使用Spark-SQL来完成复杂的需求,可以使用UDF或UDAF

  1. 查询出来所有的点击记录,并与city_info表连接,得到每个城市所在的地区,与 Product_info表连接得到商品名称。
  2. 按照地区和商品名称分组,统计出每个商品在每个地区的总点击次数。
  3. 每个地区内按照点击次数降序排列。
  4. 只取前三名,并把结果保存在数据库中。
  5. 城市备注需要自定义UDAF函数。
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object Test_Top03 {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    spark.sql("use default")

    // 3 代码

    // 注册自定义聚合函数
    spark.udf.register("city_remark", functions.udaf(new CityRemarkUDAF()))

    // 1 查询点击记录并和城市表,产品表内连接
    spark.sql(
      """
        |select
        |   c.area,--地区
        |   c.city_name,--城市
        |   p.product_name,--商品名称
        |   v.click_product_id--点击商品id
        |from user_visit_action v
        |join city_info c
        |on v.city_id=c.city_id
        |join product_info p
        |on v.click_product_id=p.product_id
        |where click_product_id>-1
        |""".stripMargin).createOrReplaceTempView("t1")

    // 2 分组计算每个区域,每个产品的点击量
    spark.sql(
      """
        |select
        |   t1.area,
        |   t1.product_name,
        |   count(*) click_count,
        |   city_remark(t1.city_name) city_remark
        |from t1
        |group by t1.area,t1.product_name
        |""".stripMargin).show()

    // 4 关闭SparkSession
    spark.stop()
  }
}

// 中间缓存数据
case class Buffer(var totalcnt: Long, var cityMap: mutable.Map[String, Long])

class CityRemarkUDAF extends Aggregator[String, Buffer, String] {
  override def zero: Buffer = Buffer(0L, mutable.Map[String, Long]())

  override def reduce(buffer: Buffer, city: String): Buffer = {
    // 总点击次数
    buffer.totalcnt += 1

    // 每个城市的点击次数
    var newCount: Long = buffer.cityMap.getOrElse(city, 0L) + 1
    buffer.cityMap.update(city, newCount)

    buffer
  }

  override def merge(b1: Buffer, b2: Buffer): Buffer = {

    // 合并所有城市的点击数量的综合
    b1.totalcnt += b2.totalcnt

    // 合并城市Map(2个Map合并)
    b2.cityMap.foreach {
      case (city, count) => {
        val newCnt: Long = b1.cityMap.getOrElse(city, 0L) + count
        b1.cityMap.update(city, newCnt)
      }
    }
    b1
  }

  override def finish(reduction: Buffer): String = {
    val remarkList: ListBuffer[String] = ListBuffer[String]()

    //将统计的城市点击数量的集合进行排序,并取出前两名
    val cityCountList: List[(String, Long)] = reduction.cityMap.toList.sortWith {
      (left, right) => {
        left._2 > right._2
      }
    }.take(2)

    var sum: Long = 0L

    // 计算出前两名的百分比
    cityCountList.foreach {
      case (city, cnt) => {
        val r = cnt * 100 / reduction.totalcnt
        remarkList.append(city + " " + r + "%")
        sum += r
      }
    }

    // 如果城市个数大于2,用其他表示
    if (reduction.cityMap.size > 2) {
      remarkList.append("其他" + (100 - sum) + "%")
    }

    remarkList.mkString(",")
  }

  override def bufferEncoder: Encoder[Buffer] = Encoders.product

  override def outputEncoder: Encoder[String] = Encoders.STRING
}
// 简化版
package com.atguigu.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object SparkSQL13_TopN_Layne {
  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "atguigu")

    //TODO 1 创建SparkConf配置文件,并设置App名称
    val conf = new SparkConf().setAppName("SparkSQLTest").setMaster("local[*]")
    //TODO 2 利用SparkConf创建sparksession对象
    val spark: SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()

    //注册自定义UDAF函数
spark.udf.register("city_remark", functions.udaf(new CityRemarkUDAF()))

    //执行sparksql
    spark.sql(
      """
        |select
        |    t3.area,
        |    t3.product_name,
        |    t3.click_count,
        |    t3.city_remark
        |from
        |(
        |    select
        |        t2.area,
        |        t2.product_name,
        |        t2.click_count,
        |        t2.city_remark,
        |        rank() over(partition by t2.area order by t2.click_count desc) rk
        |    from
        |    (
        |        select
        |            t1.area,
        |            t1.product_name,
        |            count(*) click_count,
        |            city_remark(t1.city_name) city_remark
        |        from
        |        (
        |            select
        |                c.area,
        |                c.city_name,
        |                p.product_name,
        |                v.click_product_id
        |            from user_visit_action v
        |            join city_info c
        |            on v.city_id = c.city_id
        |            join product_info p
        |            on v.click_product_id = p.product_id
        |            where v.click_product_id > -1
        |        )t1
        |        group by t1.area,t1.product_name
        |    ) t2
        |) t3
        |where t3.rk <= 3
        |""".stripMargin).show(1000, false)

    // 3 关闭资源
    spark.stop()
  }
}

/**
 * 输入: city_name    String
 * 缓冲区: 这个大区的总点击量  Map[(City_name,城市的点击量)]   Buffer
 * 输出: city_remark   String
 */
case class Buffer(var totalcnt: Long, var cityMap: mutable.Map[String, Long])

class CityRemarkUDAF extends Aggregator[String, Buffer, String] {

  override def zero: Buffer = Buffer(0L, mutable.Map[String, Long]())

  //单个分区聚合方法  buffer 和   city
  override def reduce(buffer: Buffer, city: String): Buffer = {
    buffer.totalcnt += 1
    buffer.cityMap(city) = buffer.cityMap.getOrElse(city, 0L) + 1
    buffer
  }

  //多个buffer之间的聚合
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.totalcnt += b2.totalcnt
    b2.cityMap.foreach {
      case (city, cityCnt) => {
        b1.cityMap(city) = b1.cityMap.getOrElse(city, 0L) + cityCnt
      }
    }
    b1
  }

  //最终逻辑计算方法
  override def finish(buffer: Buffer): String = {
    //0 定义一个listBuffer 用来存储最后返回结果
    val remarkList = ListBuffer[String]()

    //1 取出buffer的map,然后转成list,按照城市点击量 倒序排序
    val cityList: List[(String, Long)] = buffer.cityMap.toList.sortWith(
      (t1, t2) => t1._2 > t2._2
    )

    //2 取出排好序的cityList前两名,特殊处理
    var sum = 0L

    cityList.take(2).foreach {
      case (city, cityCnt) => {
        val res: Long = cityCnt * 100 / buffer.totalcnt
        remarkList.append(city + " " + res + "%")
        sum += res   //将前两个加起来,方便计算第三个其他百分比
      }
    }

    //3 计算第三个其他的百分比
    if (buffer.cityMap.size > 2) {
      remarkList.append("其他 " + (100 - sum) + "%")
    }

    //4 返回remarkList字符串
    remarkList.mkString(",")
  }

  override def bufferEncoder: Encoder[Buffer] = Encoders.product

  override def outputEncoder: Encoder[String] = Encoders.STRING
}
// 算子实现自定义函数

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, KeyValueGroupedDataset, Row, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object Test02_Top3 {
  def main(args: Array[String]): Unit = {
    // 1 创建Spark Session的配置对象
    // local | yarn
    val conf: SparkConf = new SparkConf()
      .setAppName("sparkSQL")
      .setMaster("local[*]")
    // 2 获取SparkSession
    val spark: SparkSession = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()

    spark.sql("use default")

    // 导入特殊依赖,隐式转换
    import spark.implicits._

    // 3 代码
    // 读取数据使用SQL最方便
    val df: DataFrame = spark.sql(
      """
        | select
        |    c.area, --地区
        |    c.city_name, -- 城市
        |    p.product_name, -- 商品名称
        |    v.click_product_id -- 点击商品id
        | from user_visit_action v
        | join city_info c
        | on v.city_id = c.city_id
        | join product_info p
        | on v.click_product_id = p.product_id
        | where click_product_id > -1
        |""".stripMargin)

    val value: KeyValueGroupedDataset[(String, String), Row] = df.groupByKey(row => {
      (row.getString(0), row.getString(2))
    })

    // tuple:(area, product_name)
    val groupMarkDS: Dataset[AreaProductInfo] = value.mapGroups((tuple: (String, String), it: Iterator[Row]) => {
      // 声明总数 和hashMap统计城市百分比
      var totalCount = 0L
      val hashMap = new mutable.HashMap[String, Long]()

      // 循环遍历一组的每一行数据
      it.foreach(row => {
        totalCount += 1
        val cityName: String = row.getString(1)
        hashMap.put(cityName, 1L + hashMap.getOrElse(cityName, 0L))
      })
      // 城市标记
      val remarkList = ListBuffer[String]()
      // 将统计的城市点击数量的集合进行排序,并取出前两名
      val cityCountList: List[(String, Long)] = hashMap.toList.sortWith(
        (left, right) => {
          left._2 > right._2
        }
      ).take(2)

      var sum: Long = 0L

      // 计算出前两名的百分比
      cityCountList.foreach {
        case (city, cnt) => {
          val r = cnt * 100 / totalCount
          remarkList.append(city + " " + r + "%")
          sum += r
        }
      }

      // 如果城市个数大于2,用其他表示
      if (hashMap.size > 2) {
        remarkList.append("其他 " + (100 - sum) + "%")
      }

      val remarkStr: String = remarkList.mkString(",")
      // 创建返回值
      AreaProductInfo(tuple._1,tuple._2,totalCount,remarkStr)
    })(Encoders.product[AreaProductInfo])

    groupMarkDS.toDF().createOrReplaceTempView("t2")

    spark.sql(
      """
        |select
        |   *,
        |   rank() over(partition by t2.area order by t2.click_count desc) rank
        |from t2
        |""".stripMargin).createOrReplaceTempView("t3")

    spark.sql(
      """
        |select
        |   *
        |from t3
        |where rank<=3
        |""".stripMargin).show()

    // 4 关闭SparkSession
    spark.stop()

  }
}

case class AreaProductInfo(area: String, product_name: String, click_count: Long, cityMark: String)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值