from pyspark.mllib.evaluation import BinaryClassificationMetrics

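# The Scala BinaryClassificationMetrics implements .roc() and .pr(); the
# subclass below reaches them through the wrapped Java object.
# See the API docs for pyspark.mllib.evaluation (Python) and
# org.apache.spark.mllib.evaluation.BinaryClassificationMetrics (Scala).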
class CurveMetrics(BinaryClassificationMetrics):
    """BinaryClassificationMetrics that can also return the points of a curve.

    The curve is produced by calling a method, selected by name, on the
    wrapped Java model (``self._java_model``), e.g. ``'roc'``.
    """

    def __init__(self, *args):
        # Pass every positional argument straight through to the parent.
        super(CurveMetrics, self).__init__(*args)

    def _to_list(self, rdd):
        """Collect ``rdd`` and return its rows as a list of float pairs.

        Rows come back as ``scala.Tuple2`` objects, which do not appear to
        have a py4j mapping, so the two fields are read via ``_1()`` and
        ``_2()`` and converted to ``float``.

        Note: the ``collect()`` could be inefficient for large datasets,
        since there may be (at most) one probability per datapoint. The
        Scala version takes a ``numBins`` parameter, but it does not seem
        possible to pass it from Python to Java.
        """
        return [(float(pair._1()), float(pair._2())) for pair in rdd.collect()]

    def get_curve(self, method):
        """Return the curve produced by the Java-side method named ``method``.

        ``method`` is looked up on ``self._java_model`` and called with no
        arguments; its result is converted with ``toJavaRDD()`` and then
        turned into a list of ``(float, float)`` tuples.
        """
        java_curve = getattr(self._java_model, method)()
        return self._to_list(java_curve.toJavaRDD())


# Convert the label and prediction columns into the RDD of
# (prediction, label) float pairs that the metrics classes take.
# NOTE(review): this statement was garbled in the source
# ("preds =,'prediction') row: ..."); the DataFrame name `predictions` and the
# `.select(...).rdd.map(lambda ...)` chain are reconstructed -- confirm against
# the original code. `label` (the label column name) and `result` (the output
# dict) are expected to be defined earlier.
preds = predictions.select(label, 'prediction').rdd.map(
    lambda row: (float(row['prediction']), float(row[label])))

# One metrics object serves both the ROC points and the area under the curve
# (CurveMetrics inherits areaUnderROC from BinaryClassificationMetrics).
metrics = CurveMetrics(preds)
points = metrics.get_curve('roc')
fpr = [x[0] for x in points]
tpr = [x[1] for x in points]
auc_roc = metrics.areaUnderROC

# The x axis is plotted as 1 - FPR, matching its "specificity" label.
result['figure'] = {'title': 'ROC曲线',
                    'AUROC': auc_roc,
                    'x': [1-f for f in fpr],
                    'y': tpr,
                    'xlabel': '特异度',
                    'ylabel': '灵敏度'}
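# Source: CSDN blog post (copyright 2020 CSDN). The remaining text here was
# web-page footer residue (keyword-marking controls, skin theme credit,
# "back to home" link), not part of the program.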