Pytorch -- sensitivity 敏感度计算
1. sensitivity是一种局部性的指标,表达 正确识别正类个数 / 正类总个数
- Sensitivity/TPR = TP / (TP + FN)
2. specificity同理,不同之处为,正确识别负类个数 / 负类总个数
- Specificity/TNR = TN / (TN + FP)
def sensitivity(output, target, sensi):
'''
这里类别数为3
传入参数:
sensi = np.array([-1] * 3) (首次,后面变为sensitivity的值)
output --> tensor(80,3) 从outputs, _ = net(inputs)中获取
target --> tensor(80)
返回值:
sensitivity --> np.array
'''
_, pred = output.max(1)
pre_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.)
acc_mask = pre_mask * tar_mask
sensitivity = acc_mask.sum(0) / tar_mask.sum(0)
sensitivity = sensitivity.numpy()
if sensi[0] != -1 :
sensitivity = (sensitivity + sensi) / 2
return sensitivity
Batch_size = 80
print(output)
tensor([[-0.0082