PyTorch-模型验证

目的:利用已经训练好的网络模型,输入数据进行测试。(相当于将其运用于真实场景中)

Cifar-10分类:

test1:放入一张狗狗的图片

进行验证:test.py

import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

img_path = './imgs/1.png'
img = Image.open(img_path)
# print(img)  # <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=153x153 at 0x24C0AA94820>
# 注意:png是四个通道,除了RGB外还有一个透明通道。
img = img.convert('RGB')
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# print(img.shape)  # torch.Size([3, 32, 32])
model = torch.load('./model_save/myModule_0.pth')
print(model)
img = torch.reshape(img, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
    output = model(img)
print(output)
print(output.argmax(1))

MyModule(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)
tensor([[-1.8267, -0.0232,  0.6100,  0.9207,  0.7240,  1.2317,  0.8643,  0.2022,
         -2.3572, -1.3339]])
tensor([5])

Process finished with exit code 0

预测结果:正确 

test2:放入一张飞机的图片

tensor([[ 0.1496,  0.5647,  0.4304,  0.0127, -0.0547, -0.0984, -0.3994,  0.4792,
         -0.1503,  0.4344]])
tensor([1]) 

预测结果:错误

因为我们加载的是“myModule_0.pth”,还没有进行太多轮训练。

可以使用训练过多轮后的模型,再进行加载。

test.py

import torch
import torchvision
from PIL import Image
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

img_path = './imgs/2.png'
img = Image.open(img_path)
# print(img)  # <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=153x153 at 0x24C0AA94820>
# 注意:png是四个通道,除了RGB外还有一个透明通道。
img = img.convert('RGB')
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor()
])
img = transform(img)
# print(img.shape)  # torch.Size([3, 32, 32])
model = torch.load('./model_save/myModule_9.pth')
print(model)
img = torch.reshape(img, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
    output = model(img)
print(output)
print(output.argmax(1))

tensor([[ 3.3083, -2.0488,  2.3623, -0.2181,  1.3219,  0.0711, -3.2308,  0.4749,
         -0.5960, -1.3306]])
tensor([0])

预测结果:正确 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值