StatScores的原理与使用
Confusion matrix (混淆矩阵)
在介绍StatScores之前,我们先复习以下Confusion matrix。
我们有两组数据,分别为真实分布,预测分布
预测为真定义为Possitive,预测为假定义为Negetive
四分类定义
- 如果预测Possitive与真实一致,则为True Possitive,简写为TP
- 如果预测Negetive与真实一致,则为True Negetive,简写为TN
- 如果预测Possitive与真实不一致,则为False Possitive,简写为FP
- 如果预测Negetive与真实不一致,则为False Negetive,简写为FN
关系图
StatScores类实际上就是统计一组预测数据的这四个分类。
额外提一下Precision与Recall
Precision(准确率) 与 Recall (召回率)
P
r
e
c
i
s
i
o
n
=
T
P
T
P
+
F
P
Precision = \cfrac {TP} {TP+FP}
Precision=TP+FPTP
R
e
c
a
l
l
=
T
P
T
P
+
F
N
Recall = \cfrac {TP} {TP+FN}
Recall=TP+FNTP
StatScores类
继承关系
直接继承与Metrics
class StatScores(Metric)
四类任务
它将处理的case分为了四类
- Binary 二分类
- MultiClass 多分类
- MultiLabel 多标签
- MultiClass&MultiLabel
没有入参指定所属的任务case,代码中是根据pred张量来判断的。逻辑如下,
因为笔者暂时只使用第1和2中,所以其他暂不介绍了。
Update与Compute方法
所有继承Metrics的子类都需要实现Update和Compute方法。
1. update
update方法中调用内部方法 _stat_scores_update
在该方法内部,首先将根据输入的数据做分类 _input_format_classification
该方法主要作用是将preds和target做one hot化,所属分类任务的case也在该方法中识别的。
_input_format_classification的四个参数
这里有三个参数注意以下:
- threshold
它仅仅作用与Binary的任务,作用是preds张量中,如果元素大于threshold,则规整为1,否则规整为0 - num_classes
指明分类种类,如果不指明的话,代码中根据元素值的最大值来判断。这个值同时也会影响one_hot后的数据长度。 - multiclass
如果multiclass=False,则强制认为所属任务为Binary。True或者不设置(None)则根据入参自行判断 - topk
在多分类任务中,在做one_hot转换时,需要返回的最大前k个位置。
比如[0.1,0.5,0.4], 在topk=1(默认时),返回的是 [0,1,0],
如果topk=2,则返回的是[0,1,1]
_stat_scores
_stat_scores是真实计算tp, fp, tn, fn四个值的地方。
举个例子
假设我们有如下
preds = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])
首先,在 _input_format_classification方法处理后,这两个张量会转换为one_hot形式如下,
preds = [[1,0], [0,1], [1,0]]
target= [[0,1], [0,1], [1,0]]
然后, 进入**_stat_scores**
第64,65行的计算结果如下:
# 预测true是正确的预测值和预测是false是正确的预测值
true_pred, false_pred = [[False,False], [True, True], [True, True]] ,
[ [True True], [False, False] [False, False]
# 预测是Ture的预测值与预测是False的预测值
pos_pred, neg_pred = [[False, True] [False, True] [True, False]] ,
[[True False] [True False] [True False]]
这两者再两两相乘,得到tp fp tn fn
tp = (true_pred * pos_pred).sum(dim=dim)
fp = (false_pred * pos_pred).sum(dim=dim)
tn = (true_pred * neg_pred).sum(dim=dim)
fn = (false_pred * neg_pred).sum(dim=dim)
2. compute
compute调用内部方法 _stat_scores_compute
_stat_scores_compute
该方法返回一个数组, [tp, fp, tn, fn, tp_fn]
这个就是StatScores的返回结果。