Running a LightGBM prediction function in PySpark fails with "expected zero arguments for construction of ClassDict (for numpy.dtype)"

Running the PySpark job raises: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)

Workflow

The model's prediction step is wrapped in a UDF for Spark to call. The UDF is invoked with the individual column values, and the wrapped function treats them as a single row of data. For example:

Sample data:

      col_1  col_2  col_3  col_4  col_5  col_6  col_7  col_8  col_9  col_10  \
0       1.0    1.0    1.0    1.0    1.0    1.0    1.0   1.00    1.0     1.0   
1       1.0    1.0    1.0    1.0    1.0    1.0    0.0   1.00    1.0     1.0   
2       1.0    1.0    1.0    1.0    1.0    1.0    1.0   1.00    1.0     1.0   
3       1.0    1.0    1.0    1.0    1.0    1.0    1.0   1.00    1.0     1.0   
4       1.0    1.0    1.0    1.0    1.0    1.0    0.0   1.00    1.0     1.0   
...     ...    ...    ...    ...    ...    ...    ...    ...    ...     ...   
1110    1.0    1.0    1.0    1.0    1.0    1.0    0.0   1.00    1.0     1.0   
1111    1.0    1.0    1.0    1.0    1.0    1.0    0.0   1.00    1.0     1.0   
1112    1.0    0.0    0.0    1.0    1.0    1.0    0.0   0.04    1.0     1.0   
1113    1.0    1.0    1.0    1.0    1.0    0.0    1.0   1.00    1.0     1.0   
1114    1.0    1.0    1.0    1.0    1.0    1.0    0.0   1.00    1.0     1.0   

predict_multicase processes one row of this data:

gbm.predict() expects a 2-D array-like input, so the row is wrapped in a list to build x_test.

from pyspark.sql import functions as F
from pyspark.sql import types as T

def predict_multicase(row):
    '''
    Predict the score for a single row.
    :param row: the feature values of one row
    :return: the predicted score
    '''
    x_test = [row]  # gbm.predict() expects a 2-D input
    ypred = gbm.predict(x_test)
    return ypred[0]

udf_predict_multicase = F.udf(predict_multicase, T.DoubleType())

# df1 is a pandas DataFrame
df = spark.createDataFrame(df1)
df_columns = df.columns
df2 = df.withColumn("rank_score", udf_predict_multicase(F.struct([df[col] for col in df_columns])))
df2.show()
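
For completeness, the snippet above assumes a SparkSession (spark) and a trained LightGBM booster (gbm) already exist on the driver; a minimal setup might look like this (the model file path is hypothetical):

import lightgbm as lgb
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("lgb-predict").getOrCreate()
gbm = lgb.Booster(model_file="model.txt")  # hypothetical path to a trained model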

Running this fails with:

Job aborted due to stage failure: Task 0 in stage 5.0 failed 4 times, most recent failure: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1$$anonfun$apply$6.apply(BatchEvalPythonExec.scala:156)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1$$anonfun$apply$6.apply(BatchEvalPythonExec.scala:155)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:834)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:834)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:43)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:89)
	at org.apache.spark.scheduler.Task.run(Task.scala:112)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:388)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
	at java.lang.Thread.run(Thread.java:745)

Driver stacktrace:

The root cause is a type mismatch: gbm.predict() returns numpy.float64 values, and Spark's pickler (net.razorvine.pickle, i.e. Pyrolite) cannot reconstruct numpy types on the JVM side. The fix is to cast the prediction to a plain Python float and change the UDF's return type to T.FloatType:
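
A quick check outside Spark (a minimal illustration added here, not from the original post) shows the offending type:

import numpy as np

v = np.float64(0.123)
print(type(v))         # <class 'numpy.float64'> - Pyrolite cannot reconstruct this on the JVM side
print(type(float(v)))  # <class 'float'> - a plain Python float, safe as a UDF return value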

def predict_multicase(row):
    '''
    Predict the score for a single row.
    :param row: the feature values of one row
    :return: the predicted score as a plain Python float
    '''
    x_test = [row]  # gbm.predict() expects a 2-D input
    ypred = gbm.predict(x_test)
    return float(ypred[0])  # cast numpy.float64 to a plain float for Spark

udf_predict_multicase = F.udf(predict_multicase, T.FloatType())

Result:

	col_1	col_2	col_3	col_4	col_5	col_6	col_7	col_8	col_9	col_10	col_11	col_12	col_13	rank_score
0	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.00	1.0	1.0	0.200000	0.0	1.0	0.018608
1	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.00	1.0	1.0	0.200000	0.0	0.0	-0.014879
2	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.00	1.0	1.0	0.166667	0.0	1.0	0.015165
3	1.0	1.0	1.0	1.0	1.0	1.0	1.0	1.00	1.0	1.0	0.166667	0.0	0.0	-0.012573
4	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.00	1.0	1.0	0.333333	0.0	0.0	-0.017949
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1110	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.00	1.0	1.0	0.250000	0.0	0.0	-0.018355
1111	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.00	1.0	1.0	0.250000	0.0	0.0	-0.018355
1112	1.0	0.0	0.0	1.0	1.0	1.0	0.0	0.04	1.0	1.0	0.200000	0.0	1.0	0.008962
1113	1.0	1.0	1.0	1.0	1.0	0.0	1.0	1.00	1.0	1.0	0.250000	0.0	1.0	0.010648
1114	1.0	1.0	1.0	1.0	1.0	1.0	0.0	1.00	1.0	1.0	0.333333	0.0	1.0	0.009966
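
The per-row UDF above works, but it calls gbm.predict() once per row. If Spark 2.3+ with PyArrow is available (an assumption, not stated in the original post), a vectorized scalar pandas_udf can score a whole batch per call; a minimal sketch:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf(T.DoubleType(), PandasUDFType.SCALAR)
def predict_batch(*cols):
    # cols is one pandas Series per input column; stack them into a 2-D matrix
    x = pd.concat(list(cols), axis=1).values
    return pd.Series(gbm.predict(x))  # one score per row in the batch

df3 = df.withColumn("rank_score", predict_batch(*[df[c] for c in df.columns]))

A side benefit: Arrow converts numpy.float64 natively, so no explicit float() cast is needed in this variant.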
