【新手适用】手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN三: 如何验证和测试模型

【新手适用】手把手教你从零开始实现一个基于Pytorch的卷积神经网络CNN二: 如何训练模型,内附详细损失、准确率、均值计算-CSDN博客

从零开始实现一个基于Pytorch的卷积神经网络 - 知乎 (zhihu.com)

1 初始化、导入模型和数据集

新建一个test.py文件,导入所需的包,并且定义测试数据集和dataloader。

测试和训练的不同点:

  • data:把train设为False
  • dataloader:不需要打乱数据集,设置shuffle=False,一批次只需要送入一张图像,batchsize=1
import torch
import torchvision
import torch.utils.data as Data

# 把train设为False
test_data = torchvision.datasets.MNIST(root='./data/', train=False, transform=torchvision.transforms.ToTensor(), download=False)
# 不需要打乱数据集,一批次只需要送入一张图像
test_loader = Data.DataLoader(test_data, batch_size=1, shuffle=False)

 定义所需的设备。

# 定义需要使用的设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 2 加载模型

加载模型需要使用torch.load(f, map_location, pickle_module, pickle_load_args)函数

  • f:我们模型文件的路径和名称,在这里是'./LeNet.pkl'
  • map_location重新映射使用的设备,一般情况下这个参数不需要任何的修改,但是如果你想要把一个用GPU训练的模型放在一个只有cpu的设备上时会发生一些错误,而这时就需要定义该参数了,可以在这里填上torch.device(device)以避免这个错误的发生。
  • pickle_module, pickle_load_args基本上不需要进行任何的设置。

加载好模型之后把模型上传到设备上。 

# 加载模型
net = torch.load('./LeNet.pkl',map_location=torch.device(device))
# 上传模型到设备
net.to(device)

3 关闭梯度

测试阶段不需要对模型的参数进行更新,可以关闭自动求导功能;并使用net.eval()方法屏蔽Dropout层、冻结BN层的参数,防止在测试阶段BN层发生参数更新

# 关闭梯度
torch.set_grad_enabled(False)
# 开启验证
net.eval()

测试及输出结果

# 获取数据及大小
length = test_data.data.size(0)
# 使用for将数据输入模型并且获得输出
for i,data in enumerate(test_loader):
    # 获得数据和标注
    x,y = data
    # 模型输出
    y_pred = net(x.to(device,torch.float))
    # 获得预测的标签
    pred = y_pred.argmax(dim=1)
    # 统计预测正确的数量
    acc += (pred.data.cpu() == y.data).sum()
    # 每一次预测后输出其预测的结果和对应的真实值
    print('Predict:', int(pred.data.cpu()), '|Ground Truth:', int(y.data))
    # 计算模型在测试集上的准确率
acc = (acc / length) * 100
# 转换为百分比的形式
print('Accuracy: %.2f' % acc, '%')

输出结果如下所示: 

  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值