CIFAR10具体分类网上很多就懒得写了,这里写一下测试图像分类效果代码,以后再研究
import matplotlib.pyplot as plt
import numpy as np
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
poto=40
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=poto, #设置转载图片个数 shuffle=False, num_workers=2)
net2= torch.load('net.pkl')
classes = ('飞机', '汽车', '鸟', '猫',
'鹿', '狗', '青蛙', '马', '船', '卡车')
def imshow(img):
img = img /2+0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
dataiter = iter(testloader) #遍历并装载神经网络
# show images
images, labels = dataiter.next() #提取标签值及图像
preds = net2(images) #神经网络计算结果
predicted = torch.argmax(preds,1) #计算当前最高准确率
for i in range(poto):
print(' '.join(classes[predicted[i]]))
img = torchvision.utils.make_grid(images[i]).numpy()
plt.imshow(np.transpose(img/2+0.5,(1,2,0)))
plt.show()
识别效果