Spark自定义函数
spark 中的 UDF (UserDefinedFunction) 大家都不会陌生, UDF 其实就是将一个普通的函数, 包装为可以按 “行“ 操作的函数, 用来处理 DataFrame 中指定的 Columns.
例如, 对某一列的所有元素进行 +1 操作, 它对应 mapreduce 操作中的 map 操作. 这种操作有的主要特点是: 行与行之间的操作是 独立 的, 可以非常方便的 并行计算 每一行的操作完成后, map 的任务就完成了, 直接将结果返回就行, 它是一种”无状态的“
但是 UDAF (UserDefinedAggregateFunction) 则不同, 由于存在聚合 (Aggregate) 操作, 它对应 mapreduce 操作中的 reduce 操作. SparkSQL中有很多现成的聚合函数, 常用的 sum, count, avg 等等都是.
这种操作的主要特点是: 每一轮 reduce 之间可以是并行, 但是多轮 reduce 的执行是 串行 的, 下一轮依靠前一轮的结果, 它是一种“有状态的”, 需要记录中间的计算结果
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions
import java.util
object udf_manage {
/**
* dt yyyy-MM-dd
*/
val getQuarterMapFunction: UserDefinedFunction = functions.udf(
(dt: String) => {
StringUtil.assertNotBlank(dt, "dt is empty!!!")
val month = dt.split("-")(1)
month match {
case "01" | "02" | "03" => "1"
case "04" | "05" | "06" => "2"
case "07" | "08" | "09" => "3"
case "10" | "11" | "12" => "4"
case _ => throw new RuntimeException(s"不支持的日期:$dt")
}
}
)
/**
* yyyy-MM-dd HH:mm:ss
*/
val getDtMapFunction: UserDefinedFunction = functions.udf((acquisitionTime: String) => {
val dt = acquisitionTime.split(" ")(0)
dt
})
val getDhMapFunction: UserDefinedFunction = functions.udf((acquisitionTime: String) => {
val dh = acquisitionTime.split(" ")(1).split(":")(0)
dh
})
val getDmMapFunction: UserDefinedFunction = functions.udf((acquisitionTime: String) => {
val dm = acquisitionTime.split(" ")(1).split(":")(1).toInt
if (dm >= 0 && dm < 15) {
"00"
} else if (dm >= 15 && dm < 30) {
"15"
} else if (dm >= 30 && dm < 45) {
"30"
} else {
"45"
}
})
val nullMapFunction: UserDefinedFunction = functions.udf(
(str: String) => {
val r = str match {
case null | "" => "NULL"
case _ => str
}
r
}
)
val natureMapFunction: UserDefinedFunction = functions.udf(
(project_nature: String) => {
val r = project_nature match {
case "366" | "368" | "385" | "386" => project_nature
case _ => "378"
}
r
}
)
val monthMapFunction: UserDefinedFunction = functions.udf(
(cost_month: String, default: String) => {
//202305
if (StringUtils.isNotBlank(cost_month) && cost_month.length == 6) {
val year = cost_month.substring(0, 4)
val month = cost_month.substring(4)
s"$year-$month-01"
} else {
default
}
})
/**
* a,b,c,c,d
* 这类以,进行拼接的string的去重计数
*/
val idsCntsUDF: UserDefinedFunction = functions.udf(
(ids: String) => {
val set = new util.HashSet[String]()
if (null != ids) {
ids.split(",").foreach(e => {
if (StringUtils.isNotBlank(e)) {
set.add(e)
}
})
}
set.size()
}
)
val avgScoreUDF: UserDefinedFunction = functions.udf(
(language: Double, math: Double, english: Double) => {
((language + math + english) / 3.0).formatted("%.2f").toDouble
}
)
/**
* x-y-z,经过指定的分隔符分隔后的第一项替换为指定的char
*/
val replaceFirst: UserDefinedFunction = functions.udf(
(str: String, split: String, expect: String) => {
val first = str.split(split)(0)
str.replace(first, expect)
}
)
}
Spark使用UDF基于某些列的计算
该方案使用udf用于对DataFrame的某些列进行组合计算映射出一个新的列,这种方案也就简化了map操作
val monthMapFunction: UserDefinedFunction = spark.udf.register("monthMapFunction", (cost_month: String,default:String) => {
//202305
if (StringUtils.isNotBlank(cost_month)) {
val year = cost_month.substring(0, 4)
val month = cost_month.substring(4)
s"$year-$month-01"
} else {
default
}
})
//加载注册的函数
udf_manage.monthMapFunction
val f2 = f1.withColumn("dMonth2", org.apache.spark.sql.functions.callUDF("monthMapFunction", lit("202305"),
lit("1970-01-01")))
UDF使用原则
//加一列,对参数dt的处理逻辑简单,自己处理
.withColumn("year", lit(dt.split(" ")(0).split("-")(0)))
//加一列,对参数dt的处理逻辑麻烦,把参数交给udf并封装过程
.withColumn("quarter", udf_manage.getQuarterMapFunction(lit(dt)))
UDF和Map函数的使用原则
当有多个列需要处理,并且处理的逻辑并不简单,则用map配合样例类,一次性处理
Hive自定义函数
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDF;
import java.util.Arrays;
import java.util.HashSet;
public class StringDistinct extends UDF {
public static void main(String[] args) {
System.out.println(new StringDistinct().evaluate("a,b,a,b,c,b,c"));
}
/**
* @param s=a,b,a,b,c,b,c
* @return a, b, c
*/
public String evaluate(final String s) {
if (StringUtils.isEmpty(s)) {
return "";
}
String s1 = new HashSet<>(Arrays.asList(s.split(","))).toString();
return s1.substring(1, s1.length() - 1).replace(", ", ",");
}
}
在hive2的节点加载jar包
add jar /mnt/db_file/jars/udf-1.0-SNAPSHOT.jar;
create temporary function idsCnts as “com.mingzhi.StringDistinctCnts”;
SELECT * from dwd_order_info_abi WHERE dt BETWEEN ‘2023-07-01’ AND ‘2023-07-31’ AND institutionid=‘481’ AND idsCnts(send_user_ids)>1;