导入依赖包
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable
from torch import optim
from torchvision import transforms
定义CNN结构
class CNN(nn.Module):
    """Four-stage convolutional classifier sized for 28x28 images (MNIST).

    Every stage is Conv2d(kernel_size=3, no padding) -> BatchNorm2d -> ReLU;
    stages 2 and 4 additionally end with a 2x2 max-pool. For a 28x28 input
    the spatial size therefore evolves as 28 -> 26 -> 24 -> 12 -> 10 -> 8 -> 4,
    which is where the ``128 * 4 * 4`` input width of the classifier head
    comes from. Other input resolutions will not match the head.

    Args:
        in_channels: Number of channels of the input images (1 = grayscale).
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head: flattened 128x4x4 feature map -> class logits.
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalized) class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, 28, 28).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)
下载并导入数据
# Preprocessing applied to every image: convert to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST train / test splits stored under ./data (download=True requests a
# download of the files).
train_set = mnist.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST(root='./data', train=False, transform=data_tf, download=True)

# Training batches are reshuffled every epoch; evaluation keeps dataset order.
train_data = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_data = DataLoader(dataset=test_set, batch_size=128, shuffle=False)
首次运行时会自动下载 MNIST 数据集到 ./data 目录；如果所在网络无法访问下载源，则需要科学上网（使用代理），否则无法下载。该步骤完成后效果图如下：
设置神经网络
# Number of full passes over the training set.
nums_epoch = 20

# Model, classification loss, and plain SGD with a learning rate of 0.1.
net = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=1e-1)
训练与测试
# Per-epoch history; each entry is a mean over all samples seen in that epoch.
losses = []
acces = []
eval_losses = []
eval_acces = []

for epoch in range(nums_epoch):
    # ---------------- training pass ----------------
    # Train mode: BatchNorm normalizes with batch statistics and updates its
    # running estimates.
    net.train()
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0
    for img, label in train_data:
        # forward
        out = net(img)
        loss = criterion(out, label)

        # backward + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion yields a per-batch mean; weight it by the batch size
        # so a smaller final batch is not over-represented in the epoch mean.
        batch_size = img.shape[0]
        train_loss_sum += loss.item() * batch_size
        _, pred = out.max(1)
        train_correct += (pred == label).sum().item()
        train_seen += batch_size

    # max(..., 1) only guards against an empty loader (division by zero).
    train_loss = train_loss_sum / max(train_seen, 1)
    train_acc = train_correct / max(train_seen, 1)
    losses.append(train_loss)
    acces.append(train_acc)

    # ---------------- evaluation pass ----------------
    # Eval mode: BatchNorm must use its stored running statistics; otherwise
    # the test batches would both influence the normalization and overwrite
    # the running estimates learned during training.
    net.eval()
    eval_loss_sum = 0.0
    eval_correct = 0
    eval_seen = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for img, label in test_data:
            out = net(img)
            loss = criterion(out, label)

            batch_size = img.shape[0]
            eval_loss_sum += loss.item() * batch_size
            _, pred = out.max(1)
            eval_correct += (pred == label).sum().item()
            eval_seen += batch_size

    eval_loss = eval_loss_sum / max(eval_seen, 1)
    eval_acc = eval_correct / max(eval_seen, 1)
    eval_losses.append(eval_loss)
    eval_acces.append(eval_acc)

    # Report once per epoch.
    print('Epoch {} Train Loss {} Train Accuracy {} Test Loss {} Test Accuracy {}'.format(
        epoch + 1, train_loss, train_acc, eval_loss, eval_acc))
这里仅显示前两个 epoch 的输出结果，可以看到准确率已经很高了。