模型验证
模型验证(测试,demo):利用已经训练好的模型,然后给它提供输入进行测试验证。
import torch
import torchvision.transforms
from PIL import Image
from torch import nn
# 需要测试的图片
image_path = "../imgs/airplane.png"
image = Image.open(image_path)
image = image.convert('RGB') # png图片多了一个透明度通道,修改成rgb三个通道
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)
# 引入网络架构
class NNN(nn.Module):
def __init__(self):
super(NNN, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 32, 5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, 5, stride=1, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
# 读取网络模型 如果保存的模型是通过gpu训练出来的,需要添加 map_location=torch.device("cpu")
model_load = torch.load("NNN_5.pth", map_location=torch.device("cpu"))
# 原有的图片是没有bitch-size的,而我们的输入是需要的
image = torch.reshape(image, (1, 3, 32, 32))
model_load.eval()
with torch.no_grad():
outputs = model_load(image)
print(outputs)
print(outputs.argmax(1))
- 找一张 你需要用训练出来的模型进行测试的图片
- 读取加载你保存的训练模型【用gpu训练的,要加上
map_location
,不然会报错】 - 把图片输入模型进行验证【注意输入图片的格式要求】
- 输出预测结果
outputs.argmax(1)