报错记录:TypeError: classification_report() takes 2 positional arguments but 3 were given

问题描述

  • 今天在使用sklearn_crfsuite.metrics.flat_classification_report函数的时候突然报错:TypeError: classification_report() takes 2 positional arguments but 3 were given,这里对该函数进行了详细剖析,找到报错原因,并给出解决办法。
    在这里插入图片描述

函数详细剖析

  • 使用sklearn_crfsuite拓展包中的metrics查看指标
from sklearn_crfsuite import metrics
# y_true为真实标签、y_pred为预测标签、labels为想要查看指标的标签(通常去除'O')
print(metrics.flat_classification_report(y_true, y_pred, labels=sort_labels))
  • 输入y_true和y_pred的输入格式通常如下图,即长度为序列个数,每个子列表为序列中每个字符的类别
    在这里插入图片描述
  • 跳转到sklearn_crfsuite.metrics.flat_classification_report函数中发现,实际调用的还是sklearn.metrics.classification_report,并使用@_flattens_y装饰器对输入数据进行展平后输入
    在这里插入图片描述
    在这里插入图片描述
  • 展平后的数据形式如下图
    在这里插入图片描述
  • 从flat_classification_report中可以看到,它将label坐标作为位置参数传入,即未指定参数名称
  • 但在下图的sklearn.metrics.classification_report函数中可以看到,其只接收y_true和y_pred两个位置参数,第三个*代表后面的参数必须指定参数名称,从而导致传入的labels成了多余的参数而报错
    在这里插入图片描述

解决办法

  • 既然都是调用的sklearn中的方法,那就自己重新实现以下过程,即展平后输入
from sklearn import metrics

y_true = [label for y in y_true for label in y]
y_pred = [label for y in y_pred for label in y]

print(metrics.classification_report(
    y_true, y_pred, labels=sort_labels
))
  • 或者使用sklearn_crfsuite里的方式进行展平
from sklearn import metrics
from itertools import chain

y_true = list(chain.from_iterable(y_true))
y_pred = list(chain.from_iterable(y_pred))

print(metrics.classification_report(
    y_true, y_pred, labels=sort_labels
))
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值