Let's look at how to define and use a UDTF in Spark SQL.
Based on Spark 2.2.0
Based on Hive 2.1.1
Historical approach
Spark 1.*
Hive 2.1.1
Spark has no built-in UDTF support. In the older Spark 1.* releases, you could implement one as a Hive UDTF and register it as a temporary function.
The UDTF class:
package com.spark.test.offline.udf

import java.util

import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}

/**
  * Created by szh on 2020/6/1.
  */
class CustomerUDTF extends GenericUDTF {

  // Declare the output schema: a single string column (named "ch" here).
  // Without this override, GenericUDTF's default initialize throws at runtime.
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    val fieldNames = new util.ArrayList[String]()
    val fieldOIs = new util.ArrayList[ObjectInspector]()
    fieldNames.add("ch")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  override def process(objects: Array[AnyRef]): Unit = {
    // Split the input string into an array of single characters
    val strLst = objects(0).toString.split("")
    for (i <- strLst) {
      val tmp: Array[String] = new Array[String](1)
      tmp(0) = i
      // forward must be passed an array, even if there is only one column
      forward(tmp)
    }
  }

  override def close(): Unit = {}
}
Registering it in Spark:
spark.sqlContext.sql("CREATE TEMPORARY FUNCTION NEWUDTF as 'com.spark.test.offline.udf.CustomerUDTF'")
Related references:
1. Error running Hive temporary UDTF on latest Spark 2.2
https://issues.apache.org/jira/browse/SPARK-21101
2. Spark SQL custom operators: UDF, UDAF, UDTF
https://blog.csdn.net/laksdbaksjfgba/article/details/87162906
Solution
The logic implemented: expand each input record into 10 records.
Use the flatMap operator to expand each Row.
The full code is as follows:
package com.spark.test.offline.udf

import com.spark.test.offline.optimize.del.User
import org.apache.spark.SparkConf
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable.ArrayBuffer

/**
  * Created by szh on 2020/6/1.
  */
object SparkSQLUdtf {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf
    conf
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // .set("spark.kryo.registrationRequired", "true")
      // Approach 1: register the classes Kryo will serialize.
      // User is assumed to be a plain case class: User(id: Int, name: String, city: String)
      .registerKryoClasses(
        Array(
          classOf[User]
          , classOf[scala.collection.mutable.WrappedArray.ofRef[_]]
        ))

    val spark = SparkSession
      .builder()
      .appName("sparkSql")
      .master("local[1]")
      .config(conf)
      .getOrCreate()
    val sc = spark.sparkContext
    sc.setLogLevel("ERROR")

    val orgRDD = sc.parallelize(Seq(
      User(1, "cc", "bj")
      , User(2, "aa", "bj")
      , User(3, "qq", "bj")
      , User(4, "pp", "bj")
    ))
    val orgDF = spark
      .createDataFrame(orgRDD)
    orgDF.show()

    // Registering a Hive UDTF fails on Spark 2.2, see SPARK-21101:
    //spark.sqlContext.sql("CREATE TEMPORARY FUNCTION NEWUDTF as 'com.spark.test.offline.udf.CustomerUDTF'")

    // Expand each Row into 10 rows with flatMap
    val midRDD = orgDF.rdd.flatMap(tmp => {
      val x = ArrayBuffer[Row]()
      for (i <- 1 to 10) {
        // Prepend, so `no` comes out in descending order (10 .. 1)
        x.+=:(Row(tmp.getInt(0), tmp.getString(1), tmp.getString(2), i))
      }
      x
    })
    println(midRDD.count())

    val finalDF = spark.createDataFrame(midRDD, StructType(
      Array(
        StructField("id", IntegerType)
        , StructField("name", StringType)
        , StructField("city", StringType)
        , StructField("no", IntegerType)
      )
    ))
    finalDF.show()

    // Keep the app alive for 10 minutes so the Spark UI can be inspected
    Thread.sleep(10 * 60 * 1000)

    sc.stop()
    spark.stop()
  }
}
Output:
+---+----+----+
| id|name|city|
+---+----+----+
| 1| cc| bj|
| 2| aa| bj|
| 3| qq| bj|
| 4| pp| bj|
+---+----+----+
40
+---+----+----+---+
| id|name|city| no|
+---+----+----+---+
| 1| cc| bj| 10|
| 1| cc| bj| 9|
| 1| cc| bj| 8|
| 1| cc| bj| 7|
| 1| cc| bj| 6|
| 1| cc| bj| 5|
| 1| cc| bj| 4|
| 1| cc| bj| 3|
| 1| cc| bj| 2|
| 1| cc| bj| 1|
| 2| aa| bj| 10|
| 2| aa| bj| 9|
| 2| aa| bj| 8|
| 2| aa| bj| 7|
| 2| aa| bj| 6|
| 2| aa| bj| 5|
| 2| aa| bj| 4|
| 2| aa| bj| 3|
| 2| aa| bj| 2|
| 2| aa| bj| 1|
+---+----+----+---+
only showing top 20 rows
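For comparison, the same 1-to-10 expansion can also be written against the typed Dataset API instead of dropping down to the RDD. A minimal sketch, assuming User is a top-level case class User(id: Int, name: String, city: String); the output case class UserNo is introduced here purely for illustration:

// Hypothetical output type for the expanded rows
case class UserNo(id: Int, name: String, city: String, no: Int)

import spark.implicits._

// flatMap on a typed Dataset: each User becomes 10 UserNo rows,
// and the encoder derives the output schema from the case class
val finalDS = orgDF.as[User].flatMap { u =>
  (1 to 10).map(i => UserNo(u.id, u.name, u.city, i))
}
finalDS.show()

This avoids spelling out the StructType by hand, since the schema is derived from UserNo.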
Maven
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>spark-test</artifactId>
<groupId>www.sunzhenhua.com</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>spark-offline</artifactId>
<dependencies>
<!-- spark community -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.11</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka-0-10_2.11</artifactId>
</dependency>
<!-- scala -->
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>2.1.0</version>
</dependency>
</dependencies>
<build>
<resources>
<resource>
<directory>src/main/resources</directory>
<filtering>true</filtering>
</resource>
</resources>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<minimizeJar>false</minimizeJar>
<shadedArtifactAttached>true</shadedArtifactAttached>
<artifactSet>
<includes>
<!-- Include here the dependencies you
want to be packed in your fat jar -->
<include>*:*</include>
</includes>
</artifactSet>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
<transformers>
<transformer
implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>