import torch
import torchvision
import torchvision.transforms as transforms
# -----------------------------------------------------------------------
# 1. Load and normalize CIFAR-10
#
# torchvision's CIFAR10 dataset yields PIL images. ToTensor() converts each
# image to a float tensor with values in [0, 1]; Normalize with mean 0.5 and
# std 0.5 per channel then maps that to [-1, 1] via (x - 0.5) / 0.5.
# -----------------------------------------------------------------------
DATA_ROOT = './data'
BATCH_SIZE = 4
# NOTE: with num_workers > 0, platforms that start DataLoader workers by
# spawning (Windows, macOS) require the script's top-level code to run under
# `if __name__ == '__main__':`.
NUM_WORKERS = 2

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True lets torchvision fetch the dataset into DATA_ROOT if it is
# not already there.
trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                       download=True, transform=transform)
# The test set is not shuffled, so evaluation visits samples in a fixed order.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

# Entry i is the human-readable name printed for integer label i.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# 2. Define a convolutional neural network (定义一个卷积神经网络)
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Small LeNet-style CNN mapping 3-channel 32x32 images to 10 class scores.

    Layout: conv(3->6, 5x5) -> ReLU -> 2x2 max-pool
            -> conv(6->16, 5x5) -> ReLU -> 2x2 max-pool
            -> fc(400->120) -> ReLU -> fc(120->84) -> ReLU -> fc(84->10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # One pooling module is reused after both conv stages; it has no weights.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # A 32x32 input reaches fc1 as 16 channels of 5x5: the unpadded
        # convs and pools give 32 -> 28 -> 14 -> 10 -> 5.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (N, 10).

        x is expected to be (N, 3, 32, 32). Inputs whose features do not
        flatten to 16*5*5 values per sample make fc1 raise a RuntimeError.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dims except the batch dim. Unlike x.view(-1, 16*5*5),
        # this cannot silently re-split a wrongly sized input into a
        # different number of rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# -----------------------------------------------------------------------
# 3. Define a loss function and optimizer
#
# Nothing earlier in this file creates `net`, `criterion` or `optimizer`,
# yet the training loop below needs all three. The choices here follow the
# official PyTorch CIFAR-10 tutorial: cross-entropy loss, SGD with momentum.
# -----------------------------------------------------------------------
import torch.optim as optim

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# -----------------------------------------------------------------------
# 4. Train the network
# -----------------------------------------------------------------------
LOG_EVERY = 2000  # mini-batches between running-loss reports

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    # enumerate() pairs each batch with its 0-based index; each batch
    # unpacks into (inputs, labels).
    for i, (inputs, labels) in enumerate(trainloader):
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print the mean loss over the last LOG_EVERY mini-batches, then reset.
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')
# 5. Test the network on the test data (对测试数据进行网络测试)
# Measure classification accuracy over the whole test set.
correct = 0
total = 0
# Inference only: no_grad() skips building the autograd graph.
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        # The predicted class is the index of the highest score in each row.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# Measure the classification accuracy of each class separately.
# Per-class tallies, indexed by integer label (same order as `classes`).
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        # One boolean per sample. Walk the actual batch length instead of a
        # hard-coded 4, so a short final batch or a different batch_size
        # cannot index out of range. Skip squeeze() so a batch of one stays 1-D.
        hits = predicted == labels
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

for i in range(len(classes)):
    if class_total[i] == 0:
        # Avoid ZeroDivisionError for a class with no test images.
        print('Accuracy of %5s : N/A (no test images)' % classes[i])
        continue
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))