简单应用
# Build a Hive-enabled SparkSession and a small sample DataFrame
# used by the UDF examples below.
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *

mission = "xxx"  # application name shown in the Spark UI
# enableHiveSupport(): connect to the Hive metastore (requires cluster config).
spark = SparkSession.builder.appName(mission).enableHiveSupport().getOrCreate()

# Sample rows: (name, age).
rows = [("A", 16), ("B", 21), ("B", 14), ("B", 18)]
df = spark.createDataFrame(rows, ["name", "age"])
def plus_one(a):
    """Return ``a + 1`` (applied element-wise once wrapped as a Spark UDF)."""
    return a + 1
# Wrap the plain Python function as a Spark UDF with an explicit return type.
plus_one_udf = udf(plus_one, returnType=LongType())
df.withColumn("one_processed", plus_one_udf(df["age"])).show()
# Decorator form: @udf turns the function into a Spark UDF directly.
@udf(returnType=LongType())
def plus_ten(a):
    """Return ``a + 10``; applied element-wise by Spark."""
    return a + 10


df.withColumn("one_processed", plus_ten(df["age"])).show()
from pyspark.sql.functions import pandas_udf, PandasUDFType


# Vectorized (pandas) UDF: `a` arrives as a pandas Series, so the addition
# runs batch-wise instead of row-by-row — much faster than a plain udf.
# NOTE(review): "hunderd" is a typo for "hundred"; kept unchanged so the
# Python name matches the SQL name registered below.
@pandas_udf('long')
def plus_hunderd(a):
    return a + 100


# Register it as well so it can be used from SQL expressions by name.
spark.udf.register('plus_hunderd', plus_hunderd)
df.withColumn("one_processed", plus_hunderd(df["age"])).show()
# Grouped-map pandas UDF: each group's rows arrive as one pandas DataFrame;
# here a new column holds `three` minus the group mean.
# NOTE(review): `sdf` is not defined in this snippet — presumably a DataFrame
# with columns `id` and `three`; confirm against the original source. The
# declared output schema (sdf.schema) must also include the `new_col` column
# the function adds, or Spark will fail at runtime — verify.
# NOTE(review): PandasUDFType.GROUPED_MAP and GroupedData.apply are
# deprecated since Spark 3.0; the modern equivalent is
# sdf.groupby("id").applyInPandas(subtract_mean, schema=sdf.schema).
@pandas_udf(sdf.schema, PandasUDFType.GROUPED_MAP)
def subtract_mean(df):
    return df.assign(new_col=df.three - df.three.mean())


sdf.groupby("id").apply(subtract_mean)
详见: pyspark dataframe之udf
实际场景
def get_group(rate):
    """Map a ratio (e.g. 0.25) to a decile bucket 1..10.

    Exact multiples of 0.10 fall into the lower bucket (0.10 -> 1,
    0.30 -> 3); anything above 0.90 is capped at bucket 10.
    Note: rate == 0 yields 0 (behavior preserved from the original).
    """
    # round(..., 9) cancels binary float representation error:
    # 0.3 * 100 == 30.000000000000004, which the raw // and % math
    # would misclassify into bucket 4 instead of 3.
    pct = round(rate * 100, 9)
    group = pct // 10 + 1
    if pct % 10 == 0:
        # Exact boundary belongs to the bucket below.
        group -= 1
    if group > 10:
        group = 10
    return int(group)
# Register get_group as a Spark UDF returning IntegerType.
group_udf = udf(get_group, IntegerType())

# Left-join the two source frames, then bucket each row's rate.
# The parenthesized chain replaces the original's broken "\ / . . ."
# continuation; intermediate transformations are omitted in the notes.
df = (
    df1.join(df2, ['region', 'id'], 'left')
    # ... other transformations omitted ...
    .withColumn('group', group_udf(col('rate')))
)