代码:
def visualize_model(model, num_images=6, device=None):
    """Plot validation images in a grid, titled with the model's predictions.

    Runs ``model`` in eval mode (under ``torch.no_grad()``) over batches from
    the module-level ``dataloaders['val']``. It draws up to ``num_images``
    images in a two-column matplotlib grid. Each image is titled with
    ``class_name[predicted_index]``, where the predicted index is the argmax
    of the model output over dimension 1. The model's original train/eval
    mode is restored before returning, even if the forward pass or plotting
    raises.

    Args:
        model: The network to evaluate. Its output is assumed to hold class
            scores along dimension 1.
        num_images: Maximum number of images to plot. Values below 1 plot
            nothing.
        device: Optional device to move each input batch to before the
            forward pass (e.g. the device ``model`` lives on). ``None``
            leaves the batch as produced by the dataloader.
    """
    if num_images < 1:
        return

    was_training = model.training
    model.eval()
    # Ceil division so odd counts (including 1) still fit in the grid;
    # ``num_images // 2`` rows would leave too few subplot slots and make
    # ``plt.subplot`` raise.
    rows = (num_images + 1) // 2
    num = 0
    plt.figure()
    try:
        with torch.no_grad():
            for images, _labels in dataloaders['val']:
                if device is not None:
                    images = images.to(device)
                out = model(images)
                _, pred = torch.max(out, 1)
                for j in range(images.size(0)):
                    num += 1
                    ax = plt.subplot(rows, 2, num)
                    ax.axis('off')
                    label = class_name[pred[j].item()]
                    ax.set_title('predicted: {}'.format(label))
                    # NOTE(review): ``imshow`` is defined elsewhere; it
                    # presumably converts the tensor to NumPy, which needs a
                    # CPU tensor. ``.cpu()`` is a no-op for CPU tensors.
                    imshow(images[j].cpu())
                    if num == num_images:
                        return
    finally:
        # ``Module.train(mode)`` propagates to every submodule. Assigning
        # ``model.training`` directly would only reset the top-level module
        # and leave layers such as Dropout/BatchNorm stuck in eval mode.
        model.train(was_training)