完整的模型验证(测试,demo)套路:利用已经训练好的模型,然后给它提供输入(对外应用)
import torch
import torchvision.transforms
from PIL import Image
from torch import nn
# 找到img.png图片----》相对路径的考察
img_path = "../image/airplane.png"
# 读取图片
img = Image.open(img_path)
print(img) # 输出:<PIL.PngImagePlugin.PngImageFile image mode=RGB size=650x558 at 0x23C3CD62160>
# 已知 模型输入为32*32,所以需要对此图片进行变化
# 首先对png的通道数进行变化,png格式是四个通道,除RGB三通道外,还有一个透明通道
img = img.convert('RGB')
# 然后对img size进行改变
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Resize((32, 32))])
img = transform(img)
print(img.shape) # 输出torch.Size([3, 32, 32])
# 拷贝网络模型
class Peipei(nn.Module):
def __init__(self):
super(Peipei, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64 * 4 * 4, 64),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
# 加载网络模型
model = Peipei()
model = torch.load("../peipei_39.pth", map_location=torch.device('cpu'))
# model.load_state_dict(torch.load("../peipei_39.pth"))
print(model)
# img输入模型之中
img = torch.reshape(img, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
output = model(img)
print(output)
# 输出预测类别
print(output.argmax(1))
输出:
由tensor([0])可知,模型预测正确