Chapter 1. SparkSQL Programming
1. Converting between RDD, DataFrame and DataSet
package com.atguigu.spark.day08
import org.junit.Test
/**
* Converting between RDD, DataFrame and DataSet
* RDD -> DataFrame: toDF method
* RDD -> DataSet: toDS method
* DataFrame -> RDD: df.rdd
* DataSet -> RDD: ds.rdd
* DataSet -> DataFrame: ds.toDF(..)
* DataFrame -> DataSet: as[row type]
* When converting a DataFrame to a DataSet, the row type is usually a tuple or a case class:
* if the row type is a tuple, the number and types of its elements must match the columns;
* if the row type is a case class, it may not have more properties than there are columns, and the property names must match the column names.
*
* Reading a value from a Row: row.getAs[column type](column name)
*/
class $01_RDDDateFramaDataSet {
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[4]").appName("test").getOrCreate()
import spark.implicits._
@Test
def cover():Unit={
val rdd = spark.sparkContext.parallelize(List((1,"zhangsan"),2->"lisi",3->"wangwu"))
//RDD -> DataFrame: toDF
val df = rdd.toDF("id", "name")
df.show
//RDD -> DataSet: toDS
val ds = rdd.toDS()
ds.show()
//DataFrame -> RDD: df.rdd
val rdd2 = df.rdd
val rdd3 = rdd2.map(row => {
//read a value from the Row by column name
val name = row.getAs[String]("name")
name
})
println(rdd3.collect().toList)
//DataSet -> RDD: ds.rdd
val rdd4 = ds.rdd
println(rdd4.collect().toList)
//DataSet -> DataFrame: ds.toDF(..)
val df2 = ds.toDF("id", "name")
df2.show()
//DataFrame -> DataSet: as[row type]
val ds3 = df.as[(Int, String)]
ds3.show()
val ds4 = df.as[AA]
ds4.show()
}
}
case class AA(name:String)
2. Differences between DataFrame and DataSet
package com.atguigu.spark.day08
import com.atguigu.spark.day07.Person
import org.junit.Test
/**
* Differences between DataFrame and DataSet:
* 1.A DataFrame only keeps track of the columns, not the type of each row, so it is weakly typed.
* A DataSet keeps track of both rows and columns, so it is strongly typed.
* 2.A DataFrame is only checked at runtime, not at compile time.
* A DataSet is checked both at compile time and at runtime.
*
* When to use DataFrame vs DataSet:
* 1.When converting an RDD for SparkSQL processing:
* if the RDD elements are tuples, prefer toDF with explicit column names to get a DataFrame;
* if the RDD elements are case classes, either conversion works.
* 2.If you need to rename columns, prefer toDF with new column names to get a DataFrame.
* 3.If you need strongly typed operators such as map and flatMap, prefer DataSet.
*
*/
class $02_DataFrameDataSet {
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[4]").appName("test").getOrCreate()
import spark.implicits._
@Test
def diff():Unit={
val list = List(
Person(1,"lisi1",21,"shenzhen"),
Person(2,"lisi2",22,"beijing"),
Person(2,"lisi2",22,"beijing"),
Person(2,"lisi2",30,"beijing"),
Person(3,"lisi3",23,"tianj"),
Person(4,"lisi4",24,"shanghai"),
Person(6,"lisi4",35,"shenzhen"),
Person(7,"lisi4",29,"hangzhou"),
Person(8,"lisi4",30,"guangzhou")
)
val df = list.toDF()
df.where("age>20").show()
df.map(row=>row.getAs[String]("name")).show()
val ds = list.toDS()
//on a DataSet, fields are accessed directly and checked at compile time
//ds.map(x => x.name).show()
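//A hedged illustration of the compile-time vs runtime difference (the misspelled name "namee" is a deliberate, hypothetical example):
//df.select("namee").show()   //compiles, but fails at runtime with an AnalysisException (cannot resolve 'namee')
//ds.map(x => x.namee)        //rejected at compile time: value namee is not a member of Person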
}
}
3. Custom UDF functions
package com.atguigu.spark.day08
/**
* UDF: one row in, one row out
* UDAF: many rows in, one row out
* UDTF: one row in, many rows out [not available in Spark]
*
* Defining a UDF in Spark:
* 1.define an ordinary function
* 2.register the function as a UDF
* 3.use it in SQL
*/
object $03_UDF {
def main(args: Array[String]): Unit = {
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[4]").appName("test").getOrCreate()
import spark.implicits._
val df = spark.sparkContext.parallelize(List(
("10001", "zhangsan"),
("00102", "zhangsan"),
("111000", "zhangsan"),
("010", "zhangsan"),
("00560", "zhangsan")
)).toDF("id", "name")
//register the DataFrame as a temporary view
df.createOrReplaceTempView("person")
//register the UDF
spark.udf.register("xxx",prefixed _)
//requirement: if an employee id is shorter than 8 characters, left-pad it with zeros
spark.sql(
"""
|select xxx(id),name from person
|""".stripMargin
).show()
}
def prefixed(id:String):String={
val currentLength = id.length
"0" * (8-currentLength) +id
}
}
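A hedged alternative sketch (not in the original notes; these lines would go inside main after df is created): the same function can be wrapped with org.apache.spark.sql.functions.udf and used through the DataFrame API without registering it by name.
import org.apache.spark.sql.functions.udf
//wrap the Scala function as a UserDefinedFunction
val prefixedUdf = udf(prefixed _)
//apply it directly to the id column
df.select(prefixedUdf($"id").as("id"), $"name").show()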
4. Custom UDAF functions (weakly typed)
- MyAvgWeakType.scala
package com.atguigu.spark.day08
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructField, StructType}
/**
* Spark 2.x style
* Custom UDAF (weakly typed); UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of Aggregator
*/
class MyAvgWeakType extends UserDefinedAggregateFunction{
//schema of the UDAF's input parameters
override def inputSchema: StructType = {
/*
First approach:
val fields = Array[StructField](
StructField("input", IntegerType)
)
val schema = StructType(fields)
schema*/
//second approach
new StructType().add("input",IntegerType)
}
//schema of the aggregation buffer (intermediate state)
override def bufferSchema: StructType = {
new StructType().add("sum",IntegerType).add("count",IntegerType)
}
//data type of the final result
override def dataType: DataType = DoubleType
//whether the function is deterministic (same input always gives the same output)
override def deterministic: Boolean = true
//initialize the aggregation buffer
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//initialize sum
buffer(0)=0
//initialize count
buffer(1)=0
}
/**
* Aggregates each incoming row into its group's buffer (combine phase, within a partition)
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//for each incoming row, add the age to sum
buffer(0) = buffer.getAs[Int](0) + input.getAs[Int](0)
//for each incoming age, increment count
buffer(1) = buffer.getAs[Int](1) + 1
}
/**
* Merges partial buffers from different partitions (reduce phase)
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//combine the partial sums of the partitions
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
//combine the partial counts of the partitions
buffer1(1) = buffer1.getAs[Int](1) + buffer2.getAs[Int](1)
}
//compute the final result
override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0).toDouble / buffer.getAs[Int](1)
}
}
- UDAF.scala
package com.atguigu.spark.day08
import com.atguigu.spark.day07.Person
object $04_UDAF {
def main(args: Array[String]): Unit = {
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[4]").appName("test").getOrCreate()
import spark.implicits._
val list = List(
Person(1,"lisi1",21,"shenzhen"),
Person(2,"lisi2",22,"beijing"),
Person(2,"lisi2",22,"beijing"),
Person(2,"lisi2",30,"beijing"),
Person(3,"lisi3",23,"tianjin"),
Person(4,"lisi4",24,"shanghai"),
Person(6,"lisi4",35,"shenzhen"),
Person(7,"lisi4",29,"hangzhou"),
Person(8,"lisi4",30,"guangzhou")
)
val df = list.toDF()
df.createOrReplaceTempView("person")
//register the UDAF (weakly typed)
spark.udf.register("myAvg",new MyAvgWeakType)
spark.sql(
"""
|select address,myAvg(age) avg_age from person group by address
|""".stripMargin
).show()
}
}
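As a quick sanity check on the sample data above, the expected group averages (row order may vary) are: beijing ≈ 24.67, shenzhen = 28.0, tianjin = 23.0, shanghai = 24.0, hangzhou = 29.0, guangzhou = 30.0.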
5. Custom UDAF functions (strongly typed)
- MyAvgStronglyType.scala
package com.atguigu.spark.day08
import org.apache.spark.sql.{Encoder,Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
* Spark 3.x style
* Custom UDAF (strongly typed, based on Aggregator)
*
*/
case class AvgBuff(var sum:Int,var count:Int)
class MyAvgStronglyType extends Aggregator[Int,AvgBuff,Double]{
//initial value of the aggregation buffer
override def zero: AvgBuff = {
AvgBuff(0,0)
}
//aggregation logic within a partition (combine phase)
override def reduce(b: AvgBuff, a: Int): AvgBuff = {
b.sum = b.sum + a
b.count = b.count + 1
b
}
//merge logic across partitions (reduce phase)
override def merge(b1: AvgBuff, b2: AvgBuff): AvgBuff = {
b1.sum = b1.sum + b2.sum
b1.count = b1.count + b2.count
b1
}
//compute the final result
override def finish(reduction: AvgBuff): Double = {
reduction.sum.toDouble / reduction.count
}
//encoder for the aggregation buffer
override def bufferEncoder: Encoder[AvgBuff] = Encoders.product[AvgBuff]
//encoder for the final result
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
- UDAF.scala
package com.atguigu.spark.day08
import com.atguigu.spark.day07.Person
object $04_UDAF {
def main(args: Array[String]): Unit = {
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[4]").appName("test").getOrCreate()
import spark.implicits._
val list = List(
Person(1,"lisi1",21,"shenzhen"),
Person(2,"lisi2",22,"beijing"),
Person(2,"lisi2",22,"beijing"),
Person(2,"lisi2",30,"beijing"),
Person(3,"lisi3",23,"tianjin"),
Person(4,"lisi4",24,"shanghai"),
Person(6,"lisi4",35,"shenzhen"),
Person(7,"lisi4",29,"hangzhou"),
Person(8,"lisi4",30,"guangzhou")
)
val df = list.toDF()
df.createOrReplaceTempView("person")
/*
Register the UDAF (weakly typed):
spark.udf.register("myAvg",new MyAvgWeakType)
*/
//register the UDAF (strongly typed): wrap the Aggregator with functions.udaf
import org.apache.spark.sql.functions._
spark.udf.register("myAvg",udaf(new MyAvgStronglyType))
spark.sql(
"""
|select address,myAvg(age) avg_age from person group by address
|""".stripMargin
).show()
}
}
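A hedged follow-up sketch (assuming the same SparkSession, spark.implicits._ and Person list as above): because the Aggregator's input type is Int, it can also be applied directly to a typed Dataset via toColumn, with no SQL registration at all.
//Dataset[Person] -> Dataset[Int] of ages -> a single aggregated Double
val avgAge = list.toDS().map(_.age).select(new MyAvgStronglyType().toColumn)
avgAge.show()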
Chapter 2. Loading and Saving Data with SparkSQL
1. Reading files
package com.atguigu.spark.day08
import java.util.Properties
import org.junit.Test
import org.apache.spark.sql.SaveMode
class $05_Reader {
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[4]").appName("test").getOrCreate()
import spark.implicits._
/**
* How to read files:
* 1.spark.read
* .format(text/json/csv/jdbc/orc/parquet) --specify the file format
* .option(K,V) --set read options
* .load(path) --load the data
* 2.spark.read.textFile/json/orc/csv/parquet
*/
@Test
def read():Unit={
/*
Read text data
spark.read.textFile("datas/wc.txt").show()
spark.read.format("text").load("datas/wc.txt").show()
*/
/*
Read JSON data
spark.read.json("datas/pmt.json").show()
spark.read.format("json").load("datas/pmt.json").show()
*/
/**
* Read CSV data
* Common options:
* sep: the field separator
* header: whether to use the first line of the file as the column names
* inferSchema: whether to infer the column types
*
* spark.read.option("header","true").option("inferSchema","true")
* .csv("datas/presidential_polls.csv").printSchema()
* spark.read.format("csv").option("header","true")
* .option("inferSchema","true").load("datas/presidential_polls.csv").show()
*
*/
/*
Save as a Parquet file
spark.read.format("csv").option("header", "true")
.option("inferSchema", "true")
.load("datas/presidential_polls.csv")
.write.mode(SaveMode.Overwrite).parquet("output/parquet")
*/
/*
Read Parquet files
spark.read.load("output/parquet").show()
spark.read.format("parquet").load("output/parquet").show
*/
}
}
2. Reading via JDBC
/**
* Read data from MySQL
*/
@Test
def readJdbc():Unit={
//first approach
spark.read.format("jdbc")
.option("url","jdbc:mysql://hadoop102:3306/gmall")
.option("dbtable","user_info")
.option("user","root")
.option("password","321074")
.load()
.show()
//second approach
//this approach produces only one partition, so it is only suitable for small amounts of data
val url = "jdbc:mysql://hadoop102:3306/gmall"
val tableName = "user_info"
val props = new Properties()
props.setProperty("user","root")
props.setProperty("password","321074")
val df = spark.read.jdbc(url, tableName, props)
df.show()
println(df.rdd.getNumPartitions)
//with this approach, number of partitions = number of WHERE conditions in the array (rarely used)
val condition =Array("id<20","id>=20 and id<50","id>=50")
val df2 = spark.read.jdbc(url,tableName,condition,props)
println(df2.rdd.getNumPartitions)
/*
Third approach (commonly used)
columnName must be a numeric, date or timestamp column
number of partitions = (upperBound - lowerBound) >= numPartitions ? numPartitions : (upperBound - lowerBound)
*/
//look up lowerBound and upperBound dynamically
val minDF = spark.read.jdbc(url,"(select min(id) min_id from user_info) user_min_id",props)
val minRdd = minDF.rdd
val minid = minRdd.collect().head.getAs[Long]("min_id")
val maxDF = spark.read.jdbc(url,"(select max(id) max_id from user_info) user_max_id",props)
val maxRdd = maxDF.rdd
val maxid = maxRdd.collect().head.getAs[Long]("max_id")
println(minid,maxid)
val df3 = spark.read.jdbc(url,tableName,"id",minid,maxid,5,props)
println(df3.rdd.getNumPartitions)
}
3. Source code: how the number of partitions is computed when reading MySQL
def columnPartition(
schema: StructType,
resolver: Resolver,
timeZoneId: String,
jdbcOptions: JDBCOptions): Array[Partition] = {
val partitioning = {
import JDBCOptions._
val partitionColumn = jdbcOptions.partitionColumn
//partitionColumn = "id"
val lowerBound = jdbcOptions.lowerBound
// lowerBound = 1
val upperBound = jdbcOptions.upperBound
// upperBound = 100
val numPartitions = jdbcOptions.numPartitions
//numPartitions = 5
//no partition column was specified
if (partitionColumn.isEmpty) {
assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " +
s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
null
} else {
//a partition column was specified
assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
s"'$JDBC_NUM_PARTITIONS' are also required")
//verify that the partition column is numeric, date or timestamp; otherwise throw an error
val (column, columnType) = verifyAndGetNormalizedPartitionColumn(
schema, partitionColumn.get, resolver, jdbcOptions)
//
val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType, timeZoneId)
//lowerBoundValue = 1L
val upperBoundValue = toInternalBoundValue(upperBound.get, columnType, timeZoneId)
//upperBoundValue = 100L
JDBCPartitioningInfo(
column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get)
}
}
//if no partition column was specified, use a single partition
if (partitioning == null || partitioning.numPartitions <= 1 ||
partitioning.lowerBound == partitioning.upperBound) {
return Array[Partition](JDBCPartition(null, 0))
}
val lowerBound = partitioning.lowerBound
//lowerBound = 1
val upperBound = partitioning.upperBound
//upperBound = 100
val boundValueToString: Long => String =
toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId)
//分区数 = (upperBound - lowerBound) >= partitioning.numPartitions ? partitioning.numPartitions : upperBound - lowerBound
val numPartitions =
if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
(upperBound - lowerBound) < 0) {
partitioning.numPartitions
} else {
logWarning("The number of partitions is reduced because the specified number of " +
"partitions is less than the difference between upper bound and lower bound. " +
s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " +
s"partitions: ${partitioning.numPartitions}; " +
s"Lower bound: ${boundValueToString(lowerBound)}; " +
s"Upper bound: ${boundValueToString(upperBound)}.")
upperBound - lowerBound
}
// stride per partition = upperBound / numPartitions - lowerBound / numPartitions = 100 / 5 - 1 / 5 = 20
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
var i: Int = 0
//column = "id"
val column = partitioning.column
var currentValue = lowerBound
//currentValue = 1
//buffer that collects the generated partitions
val ans = new ArrayBuffer[Partition]()
while (i < numPartitions) {
//first iteration: i = 0, numPartitions = 5, currentValue = 1
// lBoundValue = "1"
// lBound = null
// currentValue = currentValue + stride = 1 + 20 = 21
// uBoundValue = "21"
// uBound = s"id < 21"
// whereClause = s"id < 21 or id is null"
//second iteration: i = 1, numPartitions = 5, currentValue = 21
// lBoundValue = "21"
// lBound = s"id >= 21"
// currentValue = currentValue + stride = 21 + 20 = 41
// uBoundValue = "41"
// uBound = s"id < 41"
// whereClause = "id >= 21 and id < 41"
val lBoundValue = boundValueToString(currentValue)
val lBound = if (i != 0) s"$column >= $lBoundValue" else null
currentValue += stride
val uBoundValue = boundValueToString(currentValue)
val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
val whereClause =
if (uBound == null) {
lBound
} else if (lBound == null) {
s"$uBound or $column is null"
} else {
s"$lBound AND $uBound"
}
ans += JDBCPartition(whereClause, i)
i = i + 1
}
val partitions = ans.toArray
logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " +
partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", "))
partitions
}
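Worked through with the values annotated above (lowerBound = 1, upperBound = 100, numPartitions = 5, stride = 20), the loop produces these five WHERE clauses:
id < 21 or id is null
id >= 21 AND id < 41
id >= 41 AND id < 61
id >= 61 AND id < 81
id >= 81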
4. Writing data
@Test
def write():Unit={
val df = spark.read.json("datas/pmt.json")
//save as text: the text format needs a single string column, hence df.toJSON
val ds = df.toJSON
//ds.write.mode(SaveMode.Overwrite).text("output/text")
//ds.write.mode(SaveMode.Overwrite).format("text").save("output/text1")
//save as JSON
//df.write.mode(SaveMode.Overwrite).format("json").save("output/json")
//df.write.mode(SaveMode.Overwrite).json("output/text")
//save as Parquet
//df.write.mode(SaveMode.Overwrite).format("parquet").save("output/parquet")
//df.write.mode(SaveMode.Overwrite).parquet("output/parquet")
//save as CSV
//df.write.mode(SaveMode.Overwrite).option("sep","#").option("header","true").format("csv").save("output/csv")
//df.write.mode(SaveMode.Overwrite).option("sep","#").option("header","true").csv("output/csv")
//write data to MySQL
val props = new Properties()
props.setProperty("user","root")
props.setProperty("password","root123")
//df.write.mode(SaveMode.Append).jdbc("jdbc:mysql://hadoop102:3306/test","xx",props)
//Writing straight to MySQL as above can fail with primary-key conflicts. In that case use foreachPartition and issue INSERT INTO xx VALUES (..) ON DUPLICATE KEY UPDATE (....) yourself to upsert the rows, as sketched after this method.
//df.rdd.foreachPartition(x=> //)
}
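A minimal, hedged sketch of that upsert approach (the table xx and its id/name string columns are assumptions for illustration; the connection settings mirror the properties above): open one JDBC connection per partition and batch INSERT ... ON DUPLICATE KEY UPDATE statements so re-running the job does not fail on duplicate keys.
import java.sql.DriverManager
df.select("id", "name").rdd.foreachPartition(rows => {
  //one connection per partition, created on the executor
  val conn = DriverManager.getConnection("jdbc:mysql://hadoop102:3306/test", "root", "root123")
  val stmt = conn.prepareStatement(
    "insert into xx(id, name) values(?, ?) on duplicate key update name = values(name)")
  rows.foreach(row => {
    stmt.setString(1, row.getAs[String]("id"))
    stmt.setString(2, row.getAs[String]("name"))
    stmt.addBatch()
  })
  stmt.executeBatch()
  stmt.close()
  conn.close()
})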
5. Integrating Spark with Hive
SparkSQL can use either an embedded Hive or an external Hive; in production the external Hive is almost always used.
I. Using the embedded Hive
Go into the spark-yarn directory and simply start spark-shell.
Afterwards you will find a metastore_db directory and a derby.log file (which store the metadata) plus a spark-warehouse directory (which stores the table data).
In practice, however, almost nobody uses the embedded Hive, because its metadata lives in a Derby database, which does not support multiple concurrent clients.
II. Using an external Hive
For Spark to take over an external Hive that is already deployed, go through the following steps:
- To show the difference between the embedded and external Hive, first delete the embedded Hive's metastore_db and spark-warehouse directories
rm -rf spark-warehouse/ metastore_db/
- Copy hive-site.xml into Spark's conf directory
cp /opt/module/hive-3.1.2/conf/hive-site.xml conf/
- Copy the MySQL JDBC driver jar into Spark's jars directory
cp /opt/software/mysql-connector-java-8.0.19.jar ./jars
- Start spark-sql
bin/spark-sql
- Operating Hive from IDEA
I. Add the spark-hive dependency
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.12</artifactId>
<version>3.0.0</version>
</dependency>
II. Copy hive-site.xml into the resources directory
III. Write the code
package com.atguigu.spark.day08
object $06_SparkHive {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME","atguigu")
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder()
.master("local[4]").
appName("test")
//enable Hive support
.enableHiveSupport()
.getOrCreate()
import spark.implicits._
spark.sql("select * from student").show()
}
}
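A hedged extension of the example (the target table name student_backup is an assumption): with Hive support enabled, query results can also be written back to Hive, for example with saveAsTable.
import org.apache.spark.sql.SaveMode
spark.sql("select * from student").write.mode(SaveMode.Overwrite).saveAsTable("student_backup")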