scikit-learn issues - classification metrics can‘t handle a mix of continuous-multioutput ...

classification metrics can’t handle a mix of continuous-multioutput and multi-label-indicator targets
  • 示例场景:

    cm = confusion_matrix(y_true, y_pred)
    f1_score(y_true, y_pred, average="macro")
    
  • 问题描述:
    这个问题常见于评价多分类任务,由于sklearn的classification metrics只接受binary的targets,所以你需要确保y_ture和y_pred中的元素都是0或1的array。

  • 策略:

  1. 检查y_true是不是Binary Label;如果不是,用LB或者MLB做好转换;
  2. 检查y_pred是不是二值的;一般来说y_pred更容易造成这个问题,因为很多人会在昨晚predict之后,直接拿网络的输出(Softmax/Sigmoid 输出)做为Metrics的输入。如果属于这种情况,对于Multi-class Classification(即y_pred是Softmax的输出),可以直接用np.argmax()做转换;而对于Multi-label Classification(即y_pred是sigmoid的输出),一般使用np.around,四舍五入将y_pred调整为二值化形式( 即相当于设置了阈值0.5)。
    # Multi-class
    y_pred = np.argmax(y_pred)
    # Multi-label
    y_pred = np.around(y_pred) # np.around
    
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值