PySpark：使用 Pipeline(stages=[StringIndexer, ...]) 对多列类别特征进行编码

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, StringIndexerModel
from pyspark.sql import SparkSession
import safe_config

# Spark session for the lgb_hive_data job, tuned for a large Hive read
# (dynamic partition overwrite + Arrow transfer enabled).
spark_app_name = 'lgb_hive_data'
_spark_settings = {
    'spark.executor.memory': '13g',
    'spark.executor.cores': '3',
    'spark.driver.memory': '20g',
    'spark.executor.instances': '70',
    'spark.sql.execution.arrow.enabled': 'true',
    'spark.driver.maxResultSize': '20g',
    'spark.default.parallelism': '9000',
    'spark.sql.sources.default': 'orc',
    'spark.sql.sources.partitionOverwriteMode': 'dynamic',
    'spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation': 'true',
}
_builder = SparkSession.builder
for _key, _val in _spark_settings.items():
    _builder = _builder.config(_key, _val)
spark = _builder.appName(spark_app_name).enableHiveSupport().getOrCreate()

# Pull every row from the safe-rule table starting at partition 20200801.
df = spark.sql("""
select * from aiplatform.travel_risk_index_safe_rule where pt >= '20200801' 
""")
# One StringIndexer per text categorical column. handleInvalid="keep"
# assigns unseen labels an extra index instead of failing at transform time.
_indexers = [
    StringIndexer(inputCol=col,
                  outputCol='{}_new_col'.format(col),
                  handleInvalid="keep")
    for col in safe_config.TEXT_CATEGORICAL_COLS
]
pipeline = Pipeline(stages=_indexers)
model = pipeline.fit(df)
indexed = model.transform(df)

# Recover each indexer's ordered vocabulary from the output column metadata:
# {output_col_name: [label_at_index_0, label_at_index_1, ...]}.
index_dict = {}
for field in indexed.schema.fields:
    if field.name.endswith("_new_col"):
        index_dict[field.name] = field.metadata["ml_attr"]["vals"]

# label encoder dict
import json
import numpy as np
def key_to_json(data):
    """Convert a dict key into a JSON-compatible key.

    JSON object keys must be scalars, so tuples/frozensets are stringified
    and NumPy scalars are unwrapped to native Python numbers.

    Raises:
        TypeError: for any unsupported key type.
    """
    if data is None or isinstance(data, (bool, int, str, float)):
        return data
    if isinstance(data, (tuple, frozenset)):
        return str(data)
    if isinstance(data, np.integer):
        return int(data)
    # np.float was removed in NumPy >= 1.24; np.floating covers all NumPy
    # float scalar types (float32, float64, ...).  The original code also
    # truncated floats via int() -- return float() to preserve the value.
    if isinstance(data, np.floating):
        return float(data)
    raise TypeError(f"Unserializable key type: {type(data)!r}")
def to_json(data):
    """Recursively convert *data* into JSON-serializable built-in types.

    Sets become sorted lists, NumPy floats become Python floats, and dicts
    are rebuilt with `key_to_json` keys and recursively converted values.

    Raises:
        TypeError: for any unsupported value type.
    """
    if data is None or isinstance(data, (bool, int, tuple, range, str, list)):
        return data
    if isinstance(data, (set, frozenset)):
        return sorted(data)
    # np.float was removed in NumPy >= 1.24; np.floating matches every
    # NumPy float scalar type (float32, float64, ...).
    if isinstance(data, np.floating):
        return float(data)
    if isinstance(data, dict):
        return {key_to_json(key): to_json(data[key]) for key in data}
    raise TypeError(f"Unserializable value type: {type(data)!r}")
    
# Build the label-encoder mapping {column: {category_value: index}} from the
# indexer vocabularies and persist it as JSON for downstream use.
text_index_dict = {}
for col_idx, col_name in enumerate(index_dict):
    print(col_idx, col_name)  # progress trace, one line per encoded column
    text_index_dict[col_name] = {
        label: pos for pos, label in enumerate(index_dict[col_name])
    }
# Note: the original wrapped this constant path in a no-op f-string.
with open('./index.json', 'w') as fp:
    json.dump(to_json(text_index_dict), fp)

参考

https://stackoverflow.com/questions/45885044/getting-labels-from-stringindexer-stages-within-pipeline-in-spark-pyspark

pyspark特征工程常用方法(一)

https://blog.csdn.net/Katherine_hsr/article/details/81004708

 

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值