pyspark DecisionTreeModel不能在RDD上直接使用

训练了一个DecisionTreeModel ,然后在RDD 上准备进行验证:


dtModel     = DecisionTree.trainClassifier(data, 2, {}, impurity="entropy", maxDepth=maxTreeDepth)

predictions = dtModel.predict(data.map(lambda lp: lp.features))


def GetDtLabel(x):
    return 1 if dtModel.predict(x.features) > 0.5 else 0


dtTotalCorrect = data.map(lambda point : 1 if  GetDtLabel(point) == point.label else 0).sum()
</pre><pre class="python" name="code" snippet_file_name="blog_20160713_7_9485148" code_snippet_id="1760790">     
     提示错误:
     Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

     看scala的代码没问题,以为是dtModel需要广播一下,但是错误依旧:
      
         dtModelBroadcast = sc.broadcast(dtModel)

          最后根据下面stackoverflow提到的才发现是pyspark的问题:

 
     http://stackoverflow.com/questions/31684842/how-to-use-java-scala-function-from-an-action-or-a-transformation
 
     http://stackoverflow.com/questions/36838024/combining-spark-streaming-mllib
  
    pyspark里面 DescitionTreeModel的predict方法源代码提到
      “In Python, predict cannot currently be used within an RDD transformation or action. 
Call predict directly on the RDD instead.”
</pre><pre code_snippet_id="1760790" snippet_file_name="blog_20160713_14_3815520" name="code" class="python" style="color: rgb(36, 39, 41); font-size: 15px; line-height: 19.5px;"> def predict(self, x):
        """
        Predict the label of one or more examples.

        Note: In Python, predict cannot currently be used within an RDD
              transformation or action.
              Call predict directly on the RDD instead.

        :param x:  Data point (feature vector),
                   or an RDD of data points (feature vectors).
        """
        if isinstance(x, RDD):
            return self.call("predict", x.map(_convert_to_vector))

        else:
            return self.call("predict", _convert_to_vector(x))</span>

        这个call是调用了self._sc方法,导致了model依赖sc

class JavaModelWrapper(object):
    """
    Wrapper for the model in JVM
    """
    def __init__(self, java_model):
        self._sc = SparkContext.getOrCreate()
        self._java_model = java_model

    def __del__(self):
        self._sc._gateway.detach(self._java_model)

    def call(self, name, *a):
        """Call method of java_model"""
        return callJavaFunc(self._sc, getattr(self._java_model, name), *a)

      原因是这里通过py4j来调用java_model( "org.apache.spark.mllib.tree.model.DecisionTreeModel"),导致了依赖SparkContext。
 
 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值