六、测试网络模型
(1) 基本概念理解
需要清楚几个概念:准确度、精度、召回率
TP: True Positive,将正样本预测为正样本的样本数量(预测正确)
FN: False Negtive,将正样本预测为负样本的样本数量
FP: False Positive,将负样本预测为正样本的样本数量
TN: True Negtive,将负样本预测为正样本的样本数量(预测正确)
1. 准确度:准确度表示分类正确的样本数所占比例
ACC = ( TP + TN) / ( TP + TN + FP + FN)
2.精确度、精度:该概念是针对“预测结果”而言的。表示预测为正类的样本中有多少是真的正样本
P = TP / TP + FP
3.召回率:该概念是针对“原始样本”而言的。表示样本中的正例有多少被分类正确了,也即一种是把原来的正类预测成正类(TP),另一种就是把原来的正类预测为负类(FN)。
R = TP / TP + FN
(2) 测试网络模型、计算准确度
from torch.utils.data import DataLoader
import torch
from MyData import MyDataset
import torchvision.transforms as trans
from PIL import ImageDraw
import matplotlib.pyplot as plt
def test(self):
testloader = DataLoader(dataset=self.test_dataset, batch_size=50, shuffle=True)
net = torch.load("models/net.pth")
total = 0
for x, y in testloader:
# x , y = x.cuda(), y.cuda()
category, axes = net(x)
total += (category.round() == y[:,4]).sum() # 预测值等于标签的总数
index = category.round() == 1
"""
这里表示有小黄人的图片的索引集 (最后结果是True 和 False的集合)
形如:tensor([True, True, False, True, False, True, True, False, True, True])
"""
target = y[index] # 有小黄人的图片的标签(包括坐标和分类标签)
"""
还原有小黄人的图片,因为现在要可视化图片,所以要把之前对图片进行的归一化和去均值操作逆向还原回去。
数据预处理的时候对其做了标准化:处理后的图片=(原始img/255 - mean)/ std 那么现在计算原始图片,原始img =(处理后的图片 * std + mean)*255
"""
x = (x[index].cpu() * MyDataset.std.reshape(-1, 3, 1, 1) + MyDataset.mean.reshape(-1, 3, 1, 1)) # 还原预测为正样本的数据。不用乘以255.。trans.ToPILImage("RGB"):自动会乘以255
for j, i in enumerate(axes[index]): # j 为enumerate自动产生的索引
boxes = (i.data.cpu().numpy() * 224).astype(np.int32) # 还原预测坐标并将其转化为无符号整型
target_box = (target[j, 0:4].data.cpu().numpy() * 224).astype(np.int32) # 还原目标坐标并将其转化为无符号整型
img = trans.ToPILImage()(x[j]) # 转换图片
"""
torchvision.transforms.ToPILImage
对于一个Tensor的转化过程是:
1. 将张量的每个元素乘上255
2. 将张量的数据类型有FloatTensor转化成Uint8
3. 将张量转化成numpy的ndarray类型
4. 对ndarray对象做transpose (1, 2, 0)的操作
5. 利用Image下的fromarray函数,将ndarray对象转化成PILImage形式
6. 输出PILImage
"""
plt.clf()
plt.axis("off")
draw = ImageDraw.Draw(img)
draw.rectangle(boxes.tolist(), outline="red") # 预测值
draw.rectangle(target_box.tolist(), outline="yellow") # 原始值
plt.imshow(img)
plt.pause(1)
# 删除节点中的一些参数,为了节省内存空间
del boxes, target_box, img, draw
del x, y, category, axes, index, target
print("正确率:", total/len(category.round)) # GC
(3) 计算网络精度
"""
P = TP / TP + FP
TP: 预测为正样本的结果中,真正的正样本的数量
FP: 预测为正样本的结果中,不是真正的正样本的数量
1. 如何找到真正的正样本?TP
分析: 因为预测出来的正样本中,既包含了真正的正样本,也包含了假的正样本。只有在标签中才能准确的找到哪些是真的正样本,哪些是负样本。所以预测中的正样本的下标与标签中的正样本的下标取交集后就可以找到预测结果中真的正样本。
"""
# 先计算 TP +FP
TP + FP = (category.round() == 1).sum() # 预测为正样本的总数(包括真的正样本和假的正样本)
# 原始(标签中)为正样本的下标
bool_index1 = y[:, 4] == 1
# 然后找出标签中非零元素的索引, flatten()按行的方向降维,直接变成一行。
a_index = torch.nonzeros(bool_index1).flatten() # 找出了所有的1所在的位置,也就是真的正样本的索引
# 预测为正样本的下标
bool_index2 = category.round() == 1 # 预测为正样本的索引集
# 取出预测值中非零元素的索引
b_index = torch.nonzero(bool_index2).flatten()
"""
求原始1所在的位置与预测后1所在的位置的交集。所得结果就是预测为正样本中,预测值中真的正样本的位置。求len()得到的就是真的正样本的个数.
"""
TP = np.intersectld(a_index,b_index)
print(TP)
# p = TP /(TP+FP)
# 精度
P = len(TP) / (category.round() == 1).sum()
(4) 计算网络召回率
'''
召回率:表示样本中的正例有多少被分类正确
R = TP / TP + FN
例如: 总共有100个样本,80个正样本,20个负样本。但是预测的时候60个正样本,将20个正样本预测为了负样本。
将: 60/ (60 + 20)的值称为召回率
'''
# 召回率
R = len(TP) / (y[:,4] == 1).sum() # TP + FN 就表示正样本的数量