针对之前根据CIFAR10训练好的模型(30轮,对测试集有67.17%的正确率),并用其对输入图片进行测试
以下是输入图片,文件名为 cat_diandian.jpg:
以下为具体代码:
# 针对训练好的模型,提取出准确率最高的模型,例如My_nn_29_ac=67.17%.pth
# 导入模型
import torchvision
from net_work import *
from PIL import Image
import torch
my_nn = torch.load('./My_nn_29_ac=67.17%.pth')
# 导入图片,比如我养的猫的图
file_path = './cat_diandian.jpg' # 导入路径
image = Image.open(file_path) # 打开图片
# image.show() # 展示图片
# print(image) # 尺寸为227*145 而要求的输入图片的尺寸为32*32,因此要转换
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()])
image = transform(image)
# print(image.shape) # torch.Size([3, 32, 32])
# 由于类接收(N,C,H,W)格式,因此要reshape():
image = torch.reshape(image, (1, 3, 32, 32))
# print(image.shape) # torch.Size([1, 3, 32, 32])
# 使用类进行预测
out = my_nn(image)
class_all = {0: '飞机',
1: '手机',
2: '鸟',
3: '猫',
4: '鹿',
5: '狗',
6: '青蛙',
7: '马',
8: '船',
9: '卡车'}
index = out.argmax().item()
print('预测结果为: {}'.format(class_all[index])) # 预测结果为: 猫
可见训练模型能政策预测图片种类为猫。