背景
我最初的问题是,为什么在地图函数中使用DecisionTreeModel.predict会引发异常?并与如何使用 MLlib 在 Spark 上生成元组(原始标签,预测标签)?有关
当我们使用 Scala API 推荐的方式并使用DecisionTreeModel获得RDD[LabeledPoint]的预测时,只需在RDD上进行映射:
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
不幸的是,PySpark 中的类似方法不能很好地工作:
labelsAndPredictions = testData.map(
lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()
例外:似乎您正在尝试从广播变量,操作或转换中引用 SparkContext。 SparkContext 只能在驱动程序上使用,而不能在工作程序上运行的代码中使用。有关更多信息,请参见SPARK-5063。
而不是官方文件建议这样的事情:
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
那么这是怎么回事?此处没有广播变量,并且Scala API定义predict如下:
/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predict(features)
}
/**
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
}
因此至少乍一看,从动作或转换中调用就不是问题,因为预测似乎是本地操作。
说明
经过一番挖掘,我发现问题的根源是从DecisionTreeModel.predict调用的JavaModelWrapper.call方法。调用 Java 函数所需的访问 SparkContext:
callJavaFunc(self._sc, getattr(self._java_model, name), *a)
题
对于DecisionTreeModel.predict,有一个建议的解决方法,所有必需的代码已经是 Scala API 的一部分,但是一般来说,是否有任何优雅的方式来处理此类问题?
我现在能想到的只有解决方案才是重量级的:
通过隐式转换扩展 Spark 类或添加某种包装将所有内容推送到 JVM
直接使用 Py4j 网关