1、准备测试图片:
import torchvision
from PIL import Image
import torch
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the script still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
image_path = 'image/airplane.png'
# Open the file in a context manager so the file handle is released.
# convert('RGB') returns a new, fully loaded image (and drops any alpha
# channel a PNG may carry), so it stays usable after the file is closed.
with Image.open(image_path) as img:
    image = img.convert('RGB')
print(image)
# Resize to 32x32 and turn the PIL image into a float tensor of shape
# (C, H, W) = (3, 32, 32) with values in [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)
2、定义网络结构（torch.load 加载整个模型对象时，需要能找到该类的定义）：
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, ReLU
class Tudui(nn.Module):
    """Small CNN classifier for 3x32x32 images producing 10 class scores.

    Three 5x5 convolutions (stride 1, padding 2, so spatial size is kept),
    each followed by 2x2 max pooling, reduce a 32x32 input to 64 feature
    maps of 4x4 = 1024 features. Two linear layers then map these to 10
    raw (unnormalized) scores.
    """

    def __init__(self):
        super().__init__()
        # The attribute name `model` is kept as-is: a pickled Tudui
        # checkpoint restores its submodules under this name.
        self.model = nn.Sequential(
            Conv2d(3, 32, kernel_size=5, stride=1, padding=2),   # -> 32 x 32 x 32
            MaxPool2d(kernel_size=2),                            # -> 32 x 16 x 16
            Conv2d(32, 32, kernel_size=5, stride=1, padding=2),  # -> 32 x 16 x 16
            MaxPool2d(kernel_size=2),                            # -> 32 x 8 x 8
            Conv2d(32, 64, kernel_size=5, stride=1, padding=2),  # -> 64 x 8 x 8
            MaxPool2d(kernel_size=2),                            # -> 64 x 4 x 4
            Flatten(),                                           # -> 1024
            Linear(1024, 64),
            ReLU(),
            Linear(64, 10),
        )

    def forward(self, x):
        """Return raw scores of shape (N, 10) for a batch x of shape (N, 3, 32, 32)."""
        return self.model(x)
3、加载训练好的模型（该文件保存的是整个模型对象，而不只是参数）：
# torch.load unpickles the whole saved nn.Module (architecture + weights), so
# the Tudui class must be defined in this script before this line runs.
# map_location places every loaded tensor on `device`, the same device the
# input image is moved to later, regardless of where the checkpoint was saved.
# NOTE(review): unpickling can execute arbitrary code, so only load trusted
# files. PyTorch >= 2.6 defaults to weights_only=True, which refuses to load a
# full model object; pass weights_only=False there if this file is trusted.
model = torch.load('tudui_39.pth', map_location=device)
print(model)
4、增加 batch 维度，将图像张量形状从 (3, 32, 32) 变为 (1, 3, 32, 32)：
# Add a leading batch axis: (3, 32, 32) -> (1, 3, 32, 32). Unlike reshape,
# unsqueeze never reinterprets the data layout, so a wrongly shaped tensor
# is not silently scrambled into the expected shape.
image = image.unsqueeze(0)
5、将图片传至GPU:
# The input must live on the same device as the model's parameters.
image = image.to(device=device)
6、将模型切换至评估（eval）模式，并在 torch.no_grad() 下进行推理：
# eval() switches layers such as dropout and batch-norm to inference
# behaviour; no_grad() skips building the autograd graph, which saves memory
# and time because no backward pass is needed here.
model.eval()
with torch.no_grad():
    output = model(image)
7、输出结果:
# Show the raw score for each of the 10 classes, then the index of the
# highest-scoring class for the single image in the batch.
print(output)
predicted_class = output.argmax(dim=1).item()
print(predicted_class)
结果为:
是飞机。