pyspark target mean encoding入门版

写了一个简单版本的target mean encoding, 代码如下:

from pyspark.sql.functions import create_map
from itertolls import chain
agg = df.select([f,target]).groupnBy(f).agg(avg_(target).alias('mean'), count_(target).alias('count'))
agg = df.withColumn("smooth", (col('count') * col('mean') + m * col('mean')) / (col('mean') + m ))

agg_data = agg.select([f, 'smooth']).collect()
map_dict = {}
for r in agg_data:
	map_dict[r[0]] = row[1]

mapping_expr = create_map([lit(x) for x in chain(*map_dict.items())])
df = df.withColumn(f'{f}_encoded', mapping_expr[df[f]])


©️2020 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客 返回首页
实付0元
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值