上次用全连接网络,实现了MNIST手写字体的识别,validation的准确度为96%左右,现在我们使用卷积神经网络来看看最后的识别准确度,这里我们用LeNet-5来进行MNIST手写字体的识别:
首先还是加载数据集:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
# Preprocessing pipeline: resize to 28x28 (MNIST images are already 28x28,
# so this is effectively a no-op, kept for robustness against other inputs),
# convert PIL image -> tensor, then normalize pixel values to roughly [-1, 1].
# NOTE: the original was missing the comma after Resize([28, 28]) (SyntaxError).
transform = transforms.Compose([
    transforms.Resize([28, 28]),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Fixed keyword typos: download= (was "downlowd"), which is not a valid
# argument and raised TypeError.
train_dataset = datasets.MNIST('./data', train=True, download=True,
                               transform=transform)
validation_dataset = datasets.MNIST('./data', train=False, download=True,
                                    transform=transform)

# DataLoader's keywords are dataset= and batch_size= (originals used
# "datasets=" and "batchsize=", both invalid). Shuffle only the training set.
training_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=100, shuffle=True)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset, batch_size=100, shuffle=False)
然后定义我们的网络,LeNet-5的网络结构如下图:
两个卷积层,后面接max_pooling,然后两个全连接层,这里我们自己设置相关的参数
class LeNet(nn.Module):
    """LeNet-style CNN for 28x28 single-channel images (e.g. MNIST).

    Two conv + 2x2 max-pool stages followed by two fully connected layers.
    The final layer returns raw logits: this model is meant to be trained
    with nn.CrossEntropyLoss, which applies log-softmax internally, so no
    softmax is applied here.

    Args:
        d_out: number of output classes (10 for MNIST).
    """

    def __init__(self, d_out):
        super().__init__()  # removed stray trailing semicolon
        # 1 -> 32 channels; 3x3 kernel, stride 1, padding 1 keeps 28x28.
        self.conv1 = nn.Conv2d(1, 32, 3, 1, padding=1)
        self.pool1 = nn.MaxPool2d((2, 2))
        # 32 -> 64 channels; each 2x2 pool halves the side: 28 -> 14 -> 7.
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.fc1 = nn.Linear(7 * 7 * 64, 500)
        self.fc2 = nn.Linear(500, d_out)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        # Flatten per-sample features. x.size(0) is safer than the original
        # view(-1, 7*7*64): a wrongly shaped input now fails loudly instead
        # of silently producing a mangled batch dimension.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        # Raw logits on purpose -- nn.CrossEntropyLoss expects them.
        return self.fc2(x)
最后就是开始训练:
# ---- training setup ----
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet(10).to(device)
criterion = nn.CrossEntropyLoss()  # takes raw logits; no softmax in the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# The original defined both `epoches = 12` and `poch = 12` but then looped
# over the undefined name `epoch` (NameError). Consolidated into one variable.
epochs = 12

losses = []      # per-epoch mean training loss
val_losses = []  # per-epoch mean validation loss
val_acces = []   # per-epoch validation accuracy

for i in range(epochs):
    # ---- training pass ----
    model.train()  # original left this commented out
    running_loss = 0.0
    running_correct = 0.0
    for steps, (inputs, labels) in enumerate(training_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Call the module, not model.forward(), so hooks run.
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        _, preds = torch.max(y_pred, 1)
        running_correct += torch.sum(preds == labels.data)
        running_loss += loss.item()
    # loss.item() is already a per-sample mean (CrossEntropyLoss defaults to
    # reduction='mean'), so divide by the number of batches only; the
    # original additionally divided by the batch size, shrinking the
    # reported loss 100x. Accuracy is correct count over dataset size.
    epoch_loss = running_loss / len(training_loader)
    epoch_acc = running_correct / len(train_dataset)
    losses.append(epoch_loss)

    # ---- validation pass ----
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0.0
    with torch.no_grad():
        # Iterate the loader directly; DataLoader iterators no longer
        # expose a .next() method in modern PyTorch.
        for val_inputs, val_labels in validation_loader:
            val_inputs = val_inputs.to(device)
            val_labels = val_labels.to(device)
            val_pred = model(val_inputs)
            valloss = criterion(val_pred, val_labels)
            _, preds = torch.max(val_pred, 1)
            val_running_correct += torch.sum(preds == val_labels.data)
            val_running_loss += valloss.item()
    val_loss = val_running_loss / len(validation_loader)
    val_acc = val_running_correct / len(validation_dataset)
    val_losses.append(val_loss)
    val_acces.append(val_acc)

    print(i, ' training loss:', epoch_loss, 'epoch_acc:', epoch_acc.item())
    print(i, ' validation loss:', val_loss, 'validation_acc:', val_acc.item())

    # Refresh the loss curve image every epoch.
    plt.close()
    plt.plot(losses, label='training_loss')
    plt.plot(val_losses, label='validation_loss')
    plt.legend()
    plt.savefig('loss.png')

    # Save a checkpoint each epoch. Create the directory first (torch.save
    # does not create it) and avoid shadowing the builtin `dir`.
    ckpt_dir = './convolution_checkpoint'
    os.makedirs(ckpt_dir, exist_ok=True)
    PATH = os.path.join(ckpt_dir, 'model' + str(i) + '.pth')
    torch.save(model.state_dict(), PATH)
最终validation 集的准确度能达到 99%