import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
# Upscale the 32x32 CIFAR-10 images to 227x227 so that the AlexNet feature
# extractor below ends with a 256x6x6 map, then map each RGB channel from
# [0, 1] (ToTensor output) to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(227),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# shuffle=False keeps the evaluation order deterministic.
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

# Human-readable CIFAR-10 label names, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def imshow(img):
    """Display a normalized image tensor with matplotlib.

    Args:
        img: A channels-first (C, H, W) image tensor normalized to [-1, 1]
            (e.g. the output of ``torchvision.utils.make_grid`` on a batch
            produced by ``transform``). It may live on any device and may be
            part of an autograd graph.
    """
    # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
    img = img / 2 + 0.5
    # matplotlib needs a CPU NumPy array. .detach().cpu() makes this work for
    # GPU tensors and grad-tracking tensors too (a no-op for plain CPU tensors).
    npimg = img.detach().cpu().numpy()
    # Tensors are channels-first (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Preview one shuffled training batch together with its ground-truth labels.
dataiter = iter(trainloader)
# DataLoader iterators follow the standard iterator protocol. The legacy
# `.next()` alias was removed in recent PyTorch releases, so use built-in next().
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
# Use the actual batch length rather than assuming 32 images per batch.
print(' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))
class AlexNet(nn.Module):
    """AlexNet-style CNN classifier.

    conv2, conv4 and conv5 use ``groups=2``, mirroring the two-GPU split of the
    original AlexNet. Pooling here is 2x2 with stride 2.

    The flatten step requires the convolutional stack to end with a 256x6x6
    feature map. 227x227 RGB inputs (as produced by the Resize(227) transform)
    satisfy this: 227 -conv1-> 55 -pool-> 27 -conv2-> 27 -pool-> 13
    -conv3..5-> 13 -pool-> 6.

    Args:
        num_classes: Number of output logits produced by ``fc3``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 96, 11, 4)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2, groups=2)
        self.conv3 = nn.Conv2d(256, 384, 3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1, groups=2)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1, groups=2)
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 3, H, W) to logits of shape (N, num_classes)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(F.relu(self.conv5(x)), (2, 2))
        x = x.view(x.size(0), 256 * 6 * 6)
        # F.dropout defaults to training=True. Without passing the module's mode
        # it keeps zeroing activations even after net.eval(), making inference
        # non-deterministic. Tie it to self.training so eval() disables it.
        x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training)
        x = F.dropout(F.relu(self.fc2(x)), p=0.5, training=self.training)
        return self.fc3(x)
# Use the GPU when one is available. Otherwise fall back to the CPU instead of
# crashing when the model is moved to an unavailable CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
net = AlexNet()
net.to(device)  # nn.Module.to moves the parameters in place
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Train for 20 epochs with the SGD optimizer, printing the mean loss of every
# window of 200 mini-batches.
for epoch in range(20):
    running_loss = 0.0
    for i, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 200 == 0:
            # Report the average over the last 200 mini-batches, then reset.
            print('[%d,%5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
print('Finished Training')
# Round-trip the trained model through disk: the whole nn.Module object is
# pickled to ./model and read back.
torch.save(net, './model')
try:
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
    # to unpickle a full nn.Module such as AlexNet. This checkpoint was just
    # written by this script, so full unpickling is trusted here.
    net = torch.load('./model', weights_only=False)
except TypeError:
    # Older PyTorch releases (< 1.13) have no `weights_only` parameter and
    # reject the keyword. Their default already unpickles full modules.
    net = torch.load('./model')
print(net)
# Show one test batch and compare ground-truth labels with the predictions.
dataiter = iter(testloader)
# The legacy `.next()` alias is gone from current DataLoader iterators.
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))
# Evaluation mode for layers that honor self.training. No autograd graph is
# needed for inference.
net.eval()
with torch.no_grad():
    images, labels = images.to(device), labels.to(device)
    outputs = net(images)
    predicted = torch.argmax(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(len(predicted))))
# Overall top-1 accuracy on the training set.
# Use evaluation mode so layers that honor self.training behave deterministically.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        # Inside no_grad the logits carry no graph, so `.data` is unnecessary.
        predicted = torch.argmax(net(images), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 50000 train images: %d %%' % (100 * correct / total))
# Overall top-1 accuracy on the test set.
# Use evaluation mode so layers that honor self.training behave deterministically.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        # Inside no_grad the logits carry no graph, so `.data` is unnecessary.
        predicted = torch.argmax(net(images), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
# Per-class top-1 accuracy on the test set.
# Use evaluation mode so layers that honor self.training behave deterministically.
net.eval()
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = torch.argmax(net(images), 1)
        hits = (predicted == labels)
        # Walk the real batch length. The final test batch is shorter than 32
        # (10000 = 312 * 32 + 16). Every sample counts once toward its class
        # total whether or not it was classified correctly. The old short-batch
        # branch added only the correct ones, inflating the accuracy.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1
for i in range(len(classes)):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
Results (console output from one run of the script above):
Files already downloaded and verified
Files already downloaded and verified
horse dog forg ship plane plane truck truck truck car deer forg forg horse car horse cat plane car dog plane dog truck truck dog dog bird bird deer dog truck car
cuda
AlexNet(
(conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
(conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=2)
(conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
(conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2)
(fc1): Linear(in_features=9216, out_features=4096, bias=True)
(fc2): Linear(in_features=4096, out_features=4096, bias=True)
(fc3): Linear(in_features=4096, out_features=10, bias=True)
)
GroundTruth: cat ship ship plane forg forg car forg cat car plane truck dog horse truck ship dog horse ship forg horse plane deer truck dog bird deer plane truck forg forg dog
Predicted: cat ship ship plane forg forg cat forg cat truck plane truck dog horse truck ship dog horse ship forg horse plane deer truck deer cat deer plane truck forg forg dog
Accuracy of the network on the 50000 train images: 95 %
Accuracy of the network on the 10000 test images: 76 %
Accuracy of plane : 84 %
Accuracy of car : 87 %
Accuracy of bird : 63 %
Accuracy of cat : 65 %
Accuracy of deer : 72 %
Accuracy of dog : 66 %
Accuracy of forg : 82 %
Accuracy of horse : 82 %
Accuracy of ship : 81 %
Accuracy of truck : 84 %