【pytorch】StatScores的原理与使用

Confusion matrix (混淆矩阵)

在介绍StatScores之前,我们先复习以下Confusion matrix。

我们有两组数据,分别为真实分布预测分布
预测为真定义为Possitive,预测为假定义为Negetive

四分类定义

  1. 如果预测Possitive与真实一致,则为True Possitive,简写为TP
  2. 如果预测Negetive与真实一致,则为True Negetive,简写为TN
  3. 如果预测Possitive与真实不一致,则为False Possitive,简写为FP
  4. 如果预测Negetive与真实不一致,则为False Negetive,简写为FN

关系图

关系图
StatScores类实际上就是统计一组预测数据的这四个分类。

额外提一下Precision与Recall

Precision(准确率) 与 Recall (召回率)

P r e c i s i o n = T P T P + F P Precision = \cfrac {TP} {TP+FP} Precision=TP+FPTP
R e c a l l = T P T P + F N Recall = \cfrac {TP} {TP+FN} Recall=TP+FNTP


StatScores类

继承关系

直接继承与Metrics

class StatScores(Metric)

四类任务

它将处理的case分为了四类

  1. Binary 二分类
  2. MultiClass 多分类
  3. MultiLabel 多标签
  4. MultiClass&MultiLabel

没有入参指定所属的任务case,代码中是根据pred张量来判断的。逻辑如下,
分类图

因为笔者暂时只使用第1和2中,所以其他暂不介绍了。

Update与Compute方法

所有继承Metrics的子类都需要实现Update和Compute方法。

1. update

update
update方法中调用内部方法 _stat_scores_update
_stat_scores_update
在该方法内部,首先将根据输入的数据做分类 _input_format_classification

该方法主要作用是将preds和target做one hot化,所属分类任务的case也在该方法中识别的。

_input_format_classification的四个参数

这里有三个参数注意以下:

  • threshold
    它仅仅作用与Binary的任务,作用是preds张量中,如果元素大于threshold,则规整为1,否则规整为0
  • num_classes
    指明分类种类,如果不指明的话,代码中根据元素值的最大值来判断。这个值同时也会影响one_hot后的数据长度。
  • multiclass
    如果multiclass=False,则强制认为所属任务为Binary。True或者不设置(None)则根据入参自行判断
  • topk
    在多分类任务中,在做one_hot转换时,需要返回的最大前k个位置。
    比如[0.1,0.5,0.4], 在topk=1(默认时),返回的是 [0,1,0],
    如果topk=2,则返回的是[0,1,1]

_stat_scores

_stat_scores是真实计算tp, fp, tn, fn四个值的地方。

_stat_scores

举个例子

假设我们有如下

preds  = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

首先,在 _input_format_classification方法处理后,这两个张量会转换为one_hot形式如下,

preds = [[1,0], [0,1], [1,0]]
target= [[0,1], [0,1], [1,0]]

然后, 进入**_stat_scores**
第64,65行的计算结果如下:

# 预测true是正确的预测值和预测是false是正确的预测值
true_pred, false_pred = [[False,False], [True, True], [True, True]] , 
                           [ [True True], [False, False] [False, False]
# 预测是Ture的预测值与预测是False的预测值
pos_pred, neg_pred = [[False, True] [False, True] [True, False]] , 
                           [[True False] [True False] [True False]]

这两者再两两相乘,得到tp fp tn fn

    tp = (true_pred * pos_pred).sum(dim=dim)
    fp = (false_pred * pos_pred).sum(dim=dim)

    tn = (true_pred * neg_pred).sum(dim=dim)
    fn = (false_pred * neg_pred).sum(dim=dim)

2. compute

compute调用内部方法 _stat_scores_compute
compute

_stat_scores_compute

该方法返回一个数组, [tp, fp, tn, fn, tp_fn]
**_stat_scores_compute**
这个就是StatScores的返回结果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值