Extracting the ROC curve in PySpark

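The Scala BinaryClassificationMetrics class can return the ROC curve itself (through its roc() method), but the pyspark wrapper does not expose that method, so it has to be borrowed from the underlying Scala object: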

from pyspark.mllib.evaluation import BinaryClassificationMetrics

# Scala version implements .roc() and .pr()
# Python: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
    def __init__(self, *args):
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        points = []
        # Note this collect could be inefficient for large datasets
        # considering there may be one probability per datapoint (at most)
        # The Scala version takes a numBins parameter,
        # but it doesn't seem possible to pass this from Python to Java
        for row in rdd.collect():
            # Results are returned as type scala.Tuple2,
            # which doesn't appear to have a py4j mapping
            points.append((float(row._1()), float(row._2())))
        return points

    def get_curve(self, method):
        rdd = getattr(self._java_model, method)().toJavaRDD()
        return self._to_list(rdd)
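
To make clear what the list returned for the 'roc' method contains, here is a minimal pure-Python sketch of how (false positive rate, true positive rate) points can be derived from (score, label) pairs. It is an illustration only, not Spark's implementation: the function name roc_points is hypothetical, it uses one threshold per distinct score with no binning, and the exact points Spark returns may differ slightly (for example in how the end points are duplicated).

```python
def roc_points(score_and_labels):
    """Return ROC points as (false positive rate, true positive rate) pairs.

    score_and_labels: iterable of (score, label) pairs, where label is 1.0
    for the positive class and 0.0 for the negative class.
    Hypothetical helper for illustration; not part of pyspark.
    """
    # Sweep thresholds from the highest score down to the lowest.
    pairs = sorted(score_and_labels, key=lambda p: p[0], reverse=True)
    n_pos = sum(1 for _, label in pairs if label == 1.0)
    n_neg = len(pairs) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need at least one positive and one negative label")

    points = [(0.0, 0.0)]  # threshold above every score: nothing predicted positive
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        if label == 1.0:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last row sharing this score,
        # so tied scores act as a single threshold.
        is_last_of_score = i + 1 == len(pairs) or pairs[i + 1][0] != score
        if is_last_of_score:
            points.append((fp / n_neg, tp / n_pos))
    return points


example = [(0.9, 1.0), (0.8, 0.0), (0.7, 1.0), (0.2, 0.0)]
print(roc_points(example))
```

Because every distinct score becomes a threshold, the number of points grows with the data, which is why the collect() in _to_list can be costly on large datasets.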

Once the class is defined, it is used as follows:

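# Convert the predictions and labels into the RDD of (score, label) pairs
# that BinaryClassificationMetrics expects.
# Note: to get more than one interior point on the curve, the score should be
# continuous (e.g. the positive-class probability), not a hard 0/1 prediction.
preds = sdf.select(label, 'prediction').rdd.map(lambda row: (float(row['prediction']), float(row[label])))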
points = CurveMetrics(preds).get_curve('roc')
fpr = [x[0] for x in points]
tpr = [x[1] for x in points]
auc_roc = BinaryClassificationMetrics(preds).areaUnderROC
# The x axis is reported as specificity, which equals 1 - false positive rate.
result['figure'] = {'title': 'ROC curve',
                    'AUROC': auc_roc,
                    'x': [1-f for f in fpr],
                    'y': tpr,
                    'xlabel': 'Specificity',
                    'ylabel': 'Sensitivity'}
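
As a sanity check on the extracted points, the area under the curve can be recomputed locally with the trapezoidal rule and compared with auc_roc. The sketch below is pure Python and assumes points is a list of (fpr, tpr) tuples like the one returned by get_curve('roc'). The helper name trapezoid_auc is hypothetical, and the result should agree with areaUnderROC only approximately.

```python
def trapezoid_auc(points):
    """Area under a curve given as (x, y) tuples, using the trapezoidal rule.

    Hypothetical helper for illustration; not part of pyspark.
    """
    # Sort by x (then y) so the segments are integrated left to right.
    pts = sorted(points)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


# Stand-in for the list returned by get_curve('roc').
sample_points = [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
print(trapezoid_auc(sample_points))
```

A large gap between this value and auc_roc would suggest that the points were reordered or truncated somewhere between the Scala RDD and the Python list.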