A One-Stop Guide to pandas_udf

1. Function Definition

pyspark.sql.functions.pandas_udf(f=None, returnType=None, functionType=None)
Pandas UDFs are user-defined functions that are executed by Spark using Arrow to transfer data and pandas to work with the data, which allows vectorized operations.
In other words, data is shipped between the JVM and Python with Apache Arrow and processed with pandas, so the function body can use vectorized pandas operations instead of row-at-a-time Python calls.

Parameters

  • f: the user-defined function
  • returnType: the return type of the user-defined function
  • functionType: the kind of pandas UDF; one of the PandasUDFType values described below
  • SCALAR (default): operates on the DataFrame element-wise. It takes one or more pandas Series as input and returns a pandas Series, and is used with the DataFrame select and withColumn methods. Suited to element-wise operations.
  • SCALAR_ITER: like SCALAR, but operates on an iterator of batches, which allows large datasets to be processed more efficiently.
  • GROUPED_MAP: for grouped transformations. The function receives a pandas DataFrame per group and returns a pandas DataFrame matching the declared output schema; used with the DataFrame groupBy and apply methods.
  • GROUPED_AGG: for grouped aggregations that reduce a group of values to a single scalar; used with the DataFrame groupBy and agg methods. A notable difference from GROUPED_MAP is that the input is passed as columns rather than as a whole pdf, so you cannot hand the entire pdf to the UDF directly (see limitation 4 in section 3, which revisits this).
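
Since Spark 3.0 the functionType argument can usually be omitted: Spark infers which kind of pandas UDF you want from the Python type hints on the function. A minimal sketch of the hint-based form (assuming Spark 3.x with pyarrow installed; the names here are illustrative only):

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.appName("pandas_udf type-hint sketch").getOrCreate()

# Series -> Series type hints mark this as a SCALAR pandas UDF,
# so no explicit functionType is needed.
@pandas_udf("double")
def plus_one(s: pd.Series) -> pd.Series:
    return s + 1.0

spark.createDataFrame([(1.0,), (2.0,)], ["v"]).select(plus_one("v")).show()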

2. Code Examples

2.1 SCALAR

import numpy as np
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

spark = SparkSession.builder \
    .appName("pandas_udf scalar Example") \
    .getOrCreate()

# Build a small synthetic dataset with two groups
g = np.tile(['group a', 'group b'], 10)
x = np.linspace(0, 10., 20)
np.random.seed(3)  # set seed for reproducibility
y_lin = 2*x + np.random.rand(len(x))/10.
y_qua = 3*x**2 + np.random.rand(len(x))
df = pd.DataFrame({'group': g, 'x': x, 'y_lin': y_lin, 'y_qua': y_qua})
schema = StructType([
    StructField('group', StringType(), nullable=False),
    StructField('x', DoubleType(), nullable=False),
    StructField('y_lin', DoubleType(), nullable=False),
    StructField('y_qua', DoubleType(), nullable=False),
])
df = spark.createDataFrame(df, schema=schema)
df.show(7)
+-------+------------------+-------------------+-------------------+
|  group|                 x|              y_lin|              y_qua|
+-------+------------------+-------------------+-------------------+
|group a|               0.0|0.05507979025745755|0.28352508177131874|
|group b|0.5263157894736842|  1.123446361209179| 1.5241628490609185|
|group a|1.0526315789473684|  2.134353631786031| 3.7645534406624286|
|group b|1.5789473684210527| 3.2089774973618717| 7.6360921152062655|
|group a|2.1052631578947367|  4.299821011224239|   13.8410479099986|
|group b| 2.631578947368421|  5.352787203630186| 21.555938033209422|
|group a|3.1578947368421053|  6.328348004730595|  30.22326103930139|
+-------+------------------+-------------------+-------------------+

# Operate on a single column
# series to series pandas UDF
@F.pandas_udf(DoubleType())
def standardise(col1: pd.Series) -> pd.Series:
    return (col1 - col1.mean())/col1.std()
res = df.select(standardise(F.col('y_lin')).alias('result'))
res.show(5)
+-------------------+
|             result|
+-------------------+
|-1.6054255151193093|
|-1.4337009540623533|
|-1.2712121491623172|
| -1.098481817986802|
|-0.9231444116198374|
+-------------------+

# The same UDF registered without the decorator syntax
def standardise(col1: pd.Series) -> pd.Series:
    return (col1 - col1.mean())/col1.std()

standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("y_lin_standard", standard_udf(F.col('y_lin')))
df.show(3)
+-------+------------------+-------------------+-------------------+-------------------+
|  group|                 x|              y_lin|              y_qua|     y_lin_standard|
+-------+------------------+-------------------+-------------------+-------------------+
|group a|               0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093|
|group b|0.5263157894736842|  1.123446361209179| 1.5241628490609185|-1.4337009540623533|
|group a|1.0526315789473684|  2.134353631786031| 3.7645534406624286|-1.2712121491623172|
+-------+------------------+-------------------+-------------------+-------------------+

# A scalar pandas UDF can also take multiple columns as input
def standardise(col1: pd.Series, col2: pd.Series) -> pd.Series:
    return (col1 - col2.mean())/col1.std()

standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("ret", standard_udf(F.col('y_lin'), F.col('y_qua')))
df.show(3)
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
|  group|                 x|              y_lin|              y_qua|     y_lin_standard|                ret|
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
|group a|               0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093| -16.57141348616838|
|group b|0.5263157894736842|  1.123446361209179| 1.5241628490609185|-1.4337009540623533|-16.399688925111427|
|group a|1.0526315789473684|  2.134353631786031| 3.7645534406624286|-1.2712121491623172| -16.23720012021139|
+-------+------------------+-------------------+-------------------+-------------------+-------------------+

# Example from the official documentation: a struct column as input and output
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    s3['col2'] = s1 + s2.str.len()
    return s3

# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")
df.show()
+--------+----------+-----------------+
|long_col|string_col|       struct_col|
+--------+----------+-----------------+
|       1|  a string|{a nested string}|
+--------+----------+-----------------+
df.select(func("long_col", "string_col", "struct_col").alias('ret')).show()
+--------------------+
|                 ret|
+--------------------+
|{a nested string, 9}|
+--------------------+

# Returning a pandas DataFrame yields a struct output column
@pandas_udf("first string, last string")
def split_expand(s: pd.Series) -> pd.DataFrame:
    return s.str.split(expand=True)
df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(split_expand("name")).show()
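
A pandas UDF can also be called from Spark SQL once it is registered. A small sketch, reusing split_expand from above (the registered name split_expand_sql is arbitrary):

# Register the vectorized UDF under a SQL-callable name
spark.udf.register("split_expand_sql", split_expand)
df.createOrReplaceTempView("people")
spark.sql("SELECT split_expand_sql(name) AS parts FROM people").show()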

2.2 SCALAR_ITER

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StringType, LongType
import pandas as pd
import re

spark = SparkSession.builder \
    .appName("Pandas UDF Example") \
    .getOrCreate()

def extract_numbers(series: pd.Series) -> pd.Series:
    return series.apply(lambda x: int(re.sub(r'\D', '', x)) if re.sub(r'\D', '', x) else None)

@pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
def extract_numbers_udf(series_iter):
    for series in series_iter:
        yield extract_numbers(series)

        
data = [("abc123",), ("def456",), ("ghi789",), ("jkl0",)]
schema = "text STRING"
input_df = spark.createDataFrame(data, schema=schema)
result_df = input_df.select(extract_numbers_udf("text").alias("numbers"))
result_df.show()
+-------+
|numbers|
+-------+
|    123|
|    456|
|    789|
|      0|
+-------+
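
The practical advantage of SCALAR_ITER over SCALAR is that expensive per-task setup (loading a model, opening a connection, building a lookup table) can be done once before looping over the Arrow batches, instead of once per batch. A minimal sketch using Spark 3 type hints; the offset constant is just a stand-in for whatever state you would really initialize:

from typing import Iterator

@pandas_udf("long")
def add_offset(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    offset = 100  # hypothetical one-time initialization, done once per task
    for s in batches:
        yield s + offset

spark.range(4).select(add_offset("id")).show()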

2.3 GROUPED_MAP

# df here is the small (id, v) example frame from the official docs
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

# Input and output are both a pandas.DataFrame
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby('id').apply(subtract_mean).show()
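
In Spark 3 the same grouped-map logic is more commonly written with DataFrame.groupby().applyInPandas, which takes the plain Python function plus an output schema and needs no PandasUDFType at all. A sketch of the equivalent call on the same df:

def subtract_mean_pdf(pdf: pd.DataFrame) -> pd.DataFrame:
    # called once per group; pdf holds all rows of that group
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").applyInPandas(subtract_mean_pdf, schema="id long, v double").show()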

2.4 GROUPED_AGG

# style 1: decorator form; Spark 3 infers a grouped aggregate UDF from the Series -> float type hints
@F.pandas_udf(DoubleType())
def average_column(col1: pd.Series, col2: pd.Series) -> float:
    return (col1 + col2).mean()
res = df.groupby('group').agg(average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))

# style 2: explicit returnType and functionType
def average_(col1: pd.Series, col2: pd.Series) -> float:
    return (col1 + col2).mean()
average_column = pandas_udf(average_, DoubleType(), PandasUDFType.GROUPED_AGG)
res = df.groupby('group').agg(average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))

show_frame(res)  # show_frame: display helper from the source blog that rounds values; res.show() works as well
# +-------+------------------------+
# |group  |average of y_lin + y_qua|
# +-------+------------------------+
# |group a|104.770                 |
# |group b|121.621                 |
# +-------+------------------------+
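
A GROUPED_AGG pandas UDF can also be used as a window function over an unbounded window, which attaches the group-level aggregate to every row instead of collapsing each group to one row. A short sketch reusing average_column from above:

from pyspark.sql import Window

w = (Window.partitionBy('group')
     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
df.withColumn('avg_y', average_column(F.col('y_lin'), F.col('y_qua')).over(w)).show(3)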


3. Limitations and Workarounds

Limitations

  1. The user-defined function does not accept extra parameters.
  2. Conditional expressions (e.g. a > 1) and short-circuiting in boolean expressions (e.g. a == b) are not supported; the UDF ends up being evaluated internally for all rows regardless.
  3. pyspark.sql.types.ArrayType of pyspark.sql.types.TimestampType and nested pyspark.sql.types.StructType are currently not supported as output types.
  4. (Incorrect) When the function type is GROUPED_AGG, only a single column is supported as input, so the whole pdf cannot be passed into the UDF. This turned out not to hold, hence the "incorrect" flag; see the discussion below.

For limitation 1, extra arguments can still be fed to the UDF by wrapping it in a Python closure (a function that builds and returns the actual UDF body). First the plain version, then the parameterized one:

def sum_pd(pdf):
    v = pdf.v
    return pdf.assign(c=v.sum())
sum_udf = pandas_udf(sum_pd, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()

+---+----+----+
| id|   v|   c|
+---+----+----+
|  1| 1.0| 3.0|
|  1| 2.0| 3.0|
|  2| 3.0|18.0|
|  2| 5.0|18.0|
|  2|10.0|18.0|
+---+----+----+

# Example that adds an extra parameter via a closure
def sum_pd(pp):
    def wrap(pdf):
        v = pdf.v
        return pdf.assign(c=v.sum() + pp)
    return wrap

pp = 1
sum_p = sum_pd(pp)
sum_udf = pandas_udf(sum_p, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()

+---+----+----+
| id|   v|   c|
+---+----+----+
|  1| 1.0| 4.0|
|  1| 2.0| 4.0|
|  2| 3.0|19.0|
|  2| 5.0|19.0|
|  2|10.0|19.0|
+---+----+----+
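
For SCALAR pandas UDFs there is an alternative to the closure: pass the constant in as a literal column with F.lit, which arrives inside the UDF as a pandas Series filled with that value. A small sketch (add_constant is just an illustrative name); note this does not apply to GROUPED_MAP UDFs used via apply, which is why the closure trick above is still needed there:

@F.pandas_udf("double")
def add_constant(v: pd.Series, c: pd.Series) -> pd.Series:
    # c is a constant Series produced by F.lit(...)
    return v + c

df.withColumn("v_plus", add_constant(F.col("v"), F.lit(1.0))).show()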

Regarding limitation 4, the first thing to note is that whether multiple input columns are supported depends on how the function itself is written. In my original example the parameter was the whole pdf, so multiple columns could not be passed directly; in that case the fix is to pack the needed columns into a struct (StructType) and pass that single struct column into the UDF.
When the function is declared with multiple parameters, multiple columns are supported directly. For conciseness I still prefer the first, struct-based style.

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, struct
import pandas as pd

spark = SparkSession.builder \
    .appName("Two Day Subtract Example") \
    .getOrCreate()

data = [
    (1, "2023-04-12", 2.0),
    (1, "2023-04-11", 4.0),
    (2, "2023-04-12", 3.0),
    (2, "2023-04-11", 5.0)
]

columns = ["id", "dt", "v"]

df = spark.createDataFrame(data, columns)

def two_day_subtract(window_end, datapipe):
    def wrap(pdf):
        # pdf holds the struct values of one group; pdf[i][field] reads a field of the i-th row
        assert 0 < pdf.shape[0] <= 2
        dt = pdf[0][datapipe.dt]  # date of the first row in the group
        if pdf.shape[0] == 1:
            return pdf[0][datapipe.stat_col] * (1 if dt == window_end else -1)
        else:
            return (pdf[0][datapipe.stat_col] - pdf[1][datapipe.stat_col]) * (1 if dt == window_end else -1)

    return wrap

# Simple container recording which column holds the statistic and which holds the date
class dd:
    def __init__(self, stat_col, dt):
        self.stat_col = stat_col
        self.dt = dt

window_end = '2023-04-12'
datapipe = dd(stat_col='v', dt='dt')
idd = 'id'

substract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(substract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()

# The decorator form also works: apply pandas_udf to the inner function inside the factory
def two_day_subtract(window_end, datapipe):
    @pandas_udf("double", PandasUDFType.GROUPED_AGG)
    def wrap(pdf):
        assert 0 < pdf.shape[0] <= 2
        dt = pdf[0][datapipe.dt]
        if pdf.shape[0] == 1:
            return pdf[0][datapipe.stat_col] * (1 if dt == window_end else -1)
        else:
            return (pdf[0][datapipe.stat_col] - pdf[1][datapipe.stat_col]) * (1 if dt == window_end else -1)
    return wrap

substract_udf = two_day_subtract(window_end, datapipe)
stat_df = df.groupby(idd).agg(substract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()

# Passing multiple columns directly as separate arguments
def two_day_subtract(window_end, datapipe):
    def wrap(s1, s2):
        # s1 holds the values and s2 the dates for one group
        assert 0 < len(s1) <= 2
        dt = s2[0]
        if len(s1) == 1:
            return s1[0] * (1 if dt == window_end else -1)
        else:
            return (s1[0] - s1[1]) * (1 if dt == window_end else -1)

    return wrap

substract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(substract_udf(F.col('v'), F.col('dt')).alias("num"))
stat_df.show()

References:
1. Official PySpark documentation
2. Pandas UDFs in PySpark
3. Databricks blog
