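Goal: take a trained network model and run it on new input data, which is how the model would be used in a real application.
CIFAR-10 classification:
test1: feed in a picture of a dog
Verification script: test.py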
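import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

img_path = './imgs/1.png'
img = Image.open(img_path)
# print(img)  # <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=153x153 at 0x24C0AA94820>
# Note: this PNG is in RGBA mode, so it has an alpha (transparency) channel on top of RGB.
# Convert it to RGB so that it matches the 3 input channels of the model.
img = img.convert('RGB')

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# print(img.shape)  # torch.Size([3, 32, 32])

# Loading a whole saved model requires the MyModule class to be defined or imported in this script.
# On recent PyTorch versions (2.6 and later) this call may also need weights_only=False.
model = torch.load('./model_save/myModule_0.pth')
print(model)

# Add a batch dimension: the model expects input of shape (N, C, H, W).
img = torch.reshape(img, (1, 3, 32, 32))
model.eval()  # switch to evaluation mode
with torch.no_grad():  # gradients are not needed for inference
    output = model(img)
print(output)
print(output.argmax(1))  # index of the highest-scoring class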
MyModule(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
tensor([[-1.8267, -0.0232, 0.6100, 0.9207, 0.7240, 1.2317, 0.8643, 0.2022,
-2.3572, -1.3339]])
tensor([5])
Prediction: correct (index 5 is the dog class in CIFAR-10).
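To read the predicted index as a label, it can be mapped to a class name. The sketch below assumes the standard CIFAR-10 class order, which is also what torchvision's CIFAR10 dataset exposes through its `classes` attribute. The names `CIFAR10_CLASSES` and `class_name` are illustrative helpers, not part of the original test.py.

```python
# Standard CIFAR-10 class order (index -> label).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def class_name(index):
    """Return the CIFAR-10 label for a predicted class index (hypothetical helper)."""
    return CIFAR10_CLASSES[index]


# In test.py the index would come from output.argmax(1).item().
print(class_name(5))
```

With this helper, the `tensor([5])` printed above reads as "dog".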
test2: feed in a picture of an airplane
tensor([[ 0.1496, 0.5647, 0.4304, 0.0127, -0.0547, -0.0984, -0.3994, 0.4792,
-0.1503, 0.4344]])
tensor([1])
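Prediction: incorrect (index 1 is automobile; airplane is index 0).
The reason is that we loaded "myModule_0.pth", a checkpoint saved before the model had been trained for many epochs.
Instead, we can load a checkpoint saved after more epochs of training.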
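If several checkpoints have been saved, picking the most trained one can be automated. The following is a minimal sketch that assumes the files follow the `myModule_<epoch>.pth` naming seen in this post; `latest_checkpoint` and `CHECKPOINT_PATTERN` are hypothetical names introduced for illustration.

```python
import re

# Assumed naming scheme: myModule_<epoch>.pth
CHECKPOINT_PATTERN = re.compile(r"myModule_(\d+)\.pth$")


def latest_checkpoint(filenames):
    """Return the filename with the highest epoch number, or None if nothing matches."""
    best_name, best_epoch = None, -1
    for name in filenames:
        match = CHECKPOINT_PATTERN.search(name)
        if match and int(match.group(1)) > best_epoch:
            best_name, best_epoch = name, int(match.group(1))
    return best_name


# Epochs are compared as integers, so epoch 10 ranks above epoch 9.
print(latest_checkpoint(["myModule_0.pth", "myModule_9.pth", "myModule_10.pth"]))
```

In practice the list of names could come from `os.listdir('./model_save')`. Comparing the epoch as an integer avoids the pitfall of plain string sorting, where "myModule_9.pth" would sort after "myModule_10.pth".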
test.py
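import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

img_path = './imgs/2.png'
img = Image.open(img_path)
# print(img)  # <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=153x153 at 0x24C0AA94820>
# Note: this PNG is in RGBA mode, so it has an alpha (transparency) channel on top of RGB.
# Convert it to RGB so that it matches the 3 input channels of the model.
img = img.convert('RGB')

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# print(img.shape)  # torch.Size([3, 32, 32])

# Load the checkpoint saved after more epochs of training (myModule_9.pth instead of myModule_0.pth).
# The MyModule class must be defined or imported in this script for the load to succeed.
model = torch.load('./model_save/myModule_9.pth')
print(model)

# Add a batch dimension: the model expects input of shape (N, C, H, W).
img = torch.reshape(img, (1, 3, 32, 32))
model.eval()  # switch to evaluation mode
with torch.no_grad():  # gradients are not needed for inference
    output = model(img)
print(output)
print(output.argmax(1))  # index of the highest-scoring class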
tensor([[ 3.3083, -2.0488, 2.3623, -0.2181, 1.3219, 0.0711, -3.2308, 0.4749,
-0.5960, -1.3306]])
tensor([0])
Prediction: correct (index 0 is airplane).
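The tensors printed above are raw scores (logits), not probabilities. Applying softmax makes the two airplane results easier to compare. The sketch below copies the logits from the two outputs above and uses only the standard library; in test.py the same thing could be done with `torch.softmax(output, dim=1)`. The names `softmax`, `top_class`, `LOGITS_EPOCH0` and `LOGITS_EPOCH9` are illustrative.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_class(logits):
    """Return (index, probability) of the most likely class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


# Logits copied from the airplane outputs printed above.
LOGITS_EPOCH0 = [0.1496, 0.5647, 0.4304, 0.0127, -0.0547,
                 -0.0984, -0.3994, 0.4792, -0.1503, 0.4344]
LOGITS_EPOCH9 = [3.3083, -2.0488, 2.3623, -0.2181, 1.3219,
                 0.0711, -3.2308, 0.4749, -0.5960, -1.3306]

for label, logits in (("myModule_0", LOGITS_EPOCH0), ("myModule_9", LOGITS_EPOCH9)):
    index, prob = top_class(logits)
    print(f"{label}: class {index}, probability {prob:.2f}")
```

The early checkpoint assigns only about 0.15 to its top class, which is close to a uniform guess over 10 classes, while the later checkpoint assigns roughly 0.59 to class 0. This is consistent with the explanation that the first model had simply not been trained enough.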