pyspark中dataframe可以使用很多sql型的函数,比如group by、agg等,函数中经常需要调用自定义的udf函数。
以下面的udf为例,首先定义函数,函数的功能是计算分位数,
传入的col_collect_list是一个数组,由dataframe的sql函数collect_list(col)得到,传入的num是分位数的档,比如95分位就传入95;
from pyspark.sql import functions as F
from pyspark.sql import types as T
# 定义udf函数,输入一个数组,返回一个double类型的分位数结果,这里返回值是numpy类型的,要改成python的double类型
def quantie_cal(col_collect_list, num):
quantie_result = np.percentile(col_collect_list, num)
quantie_result = round(quantie_result, 2)
return float(quantie_result)
# udf的返回值需要有类型,这里是浮点数类型
udf_quantie_cal = F.udf(quantie_cal, T.DoubleType())
# 在df里进行聚合操作
valid_dwd_df = valid_dwd_df.groupBy(["routing_type", "hdmap_district", "task_purpose"]).agg(
udf_quantie_cal(collect_list("each_order_distance"), lit(95)).alias("each_order_distance_95_per"),
udf_quantie_cal(collect_list("each_order_distance"), lit(99)).alias("each_order_distance_99_per")
)
valid_dwd_df = valid_dwd_df.withColumn("dt", lit(start_time_pre))
定义好函数后,需要注册成udf函数,F.udf()前面是函数名,后面是返回的数据类型
最后在dataframe里,对三个字段进行group by聚合,然后agg里只能用udf函数对聚合字段进行处理,不能在agg()里直接用np.percentile()函数来算分位数,
udf里传入的参数必须是dataframe的列:
collect_list(“col”)是将某一列聚合成一个不去重的数组;
95不能直接传值,因为udf里的参数必须是一列,所以这里用lit(95)处理,将数值95做成一列,值都是95。
最后udf的返回值必须包裹一层返回python的数据类型,否则会报错expected zero arguments for construction of ClassDict (for numpy.dtype);这是因为udf里返回值类型都是numpy的数据类型,需要转换成python的数据类型才可以。
需要注意的是,alias()起别名后在dataframe会增加一个新列,和withColumn()效果一致,最后dataframe等于是在原有定义的基础上新增了3列each_order_95_per、each_order_99_per、dt;给下游使用时可以直接访问这三列,不需要再额外withColumn()