PyTorch Learning, Lesson 4: Convolutional Network for MNIST Recognition

Last time we used a fully connected network to recognize MNIST handwritten digits and reached a validation accuracy of about 96%. Now let's see what accuracy a convolutional neural network can achieve. Here we use a LeNet-5 style network for MNIST handwritten digit recognition.

First, load the dataset as before:

import os
import torch
import torch.nn as nn
from torchvision import datasets,transforms
import numpy as np
import matplotlib.pyplot as plt


transform=transforms.Compose([transforms.Resize([28,28]),
                              transforms.ToTensor(),
                              # Normalize maps pixel values from [0,1] to roughly [-1,1]
                              transforms.Normalize((0.5,),(0.5,))])
train_dataset=datasets.MNIST('./data',train=True,download=True,transform=transform)
validation_dataset=datasets.MNIST('./data',train=False,download=True,transform=transform)

training_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=100,shuffle=True)
validation_loader=torch.utils.data.DataLoader(dataset=validation_dataset,batch_size=100,shuffle=False)
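
As a quick sanity check, we can pull one batch from the loader and confirm its shape; each batch should be 100 single-channel 28x28 images (the print statements below are just an illustrative sketch):

images,labels=next(iter(training_loader))
print(images.shape)   # expected: torch.Size([100, 1, 28, 28])
print(labels.shape)   # expected: torch.Size([100])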

Next, define the network. The LeNet-5 architecture is shown in the figure below:
(LeNet-5 architecture diagram)
Two convolutional layers, each followed by max pooling, then two fully connected layers. Here we pick the layer parameters ourselves, so they differ slightly from the original LeNet-5:

class LeNet(nn.Module):
	def __init__(self,d_out):
		super().__init__()
		self.conv1=nn.Conv2d(1,32,3,1,padding=1)
		self.pool1=nn.MaxPool2d((2,2))
		self.conv2=nn.Conv2d(32,64,3,1,padding=1)
		self.pool2=nn.MaxPool2d((2,2))
		# the input is 28*28; after two 2x2 max-poolings the feature map becomes 28/(2*2)=7
		self.fc1=nn.Linear(7*7*64,500)
		self.fc2=nn.Linear(500,d_out)
	def forward(self,x):
		x=torch.relu(self.conv1(x))
		x=self.pool1(x)
		x=torch.relu(self.conv2(x))
		x=self.pool2(x)
		
		x=x.view(-1,7*7*64)
		x=torch.relu(self.fc1(x))
		x=self.fc2(x) # nn.CrossEntropyLoss applies log-softmax internally, so no softmax here

		return x
	

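Before training, a quick shape check (a minimal sketch; the names net and dummy are only for illustration) confirms that a 28x28 input really does flatten to 7*7*64 after the two conv+pool stages:

net=LeNet(10)
dummy=torch.zeros(1,1,28,28)   # one fake grayscale image
feat=net.pool2(torch.relu(net.conv2(net.pool1(torch.relu(net.conv1(dummy))))))
print(feat.shape)              # expected: torch.Size([1, 64, 7, 7]) -> 7*7*64 when flattened
print(net(dummy).shape)        # expected: torch.Size([1, 10])
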
Finally, start training:

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=LeNet(10).to(device)

criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

epochs=12
losses=[]

val_losses=[]
val_acces=[]

for i in range(epochs):
    
    running_loss=0.0
    running_correct=0.0
    
   
    for steps,(inputs,labels) in enumerate(training_loader):

        inputs=inputs.to(device)
        labels=labels.to(device)
        
        y_pred=model(inputs)

        loss=criterion(y_pred,labels)
        #print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        _,preds=torch.max(y_pred,1)
        running_correct+=torch.sum(preds==labels.data)
        running_loss=running_loss+loss.item()
        
  
    # criterion already averages the loss over each batch, so average over batches here;
    # accuracy is the total number of correct predictions divided by the dataset size
    epoch_loss=running_loss/len(training_loader)
    epoch_acc=running_correct.double()/len(train_dataset)
    losses.append(epoch_loss)
    
    val_running_loss=0.0
    val_running_correct=0.0

    with torch.no_grad():
        for val_inputs,val_labels in validation_loader:
            val_inputs=val_inputs.to(device)
            val_labels=val_labels.to(device)

            val_pred=model(val_inputs)

            valloss=criterion(val_pred,val_labels)

            _,preds=torch.max(val_pred,1)
            val_running_correct+=torch.sum(preds==val_labels.data)
            val_running_loss=val_running_loss+valloss.item()

    val_loss=val_running_loss/len(validation_loader)
    val_acc=val_running_correct.double()/len(validation_dataset)
    val_losses.append(val_loss)
    val_acces.append(val_acc)
    
    print(i,' training loss:',epoch_loss,'epoch_acc:',epoch_acc.item())
    print(i,' validation loss:',val_loss,'validation_acc:',val_acc.item())
    
    plt.close()
    plt.plot(losses,label='training_loss')
    plt.plot(val_losses,label='validation_loss')
    plt.legend()
    plt.savefig('loss.png')
    ckpt_dir='./convolution_checkpoint'
    os.makedirs(ckpt_dir,exist_ok=True)   # make sure the checkpoint folder exists
    PATH=os.path.join(ckpt_dir,'model'+str(i)+'.pth')
    torch.save(model.state_dict(), PATH)

The validation accuracy finally reaches about 99%.
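
To reuse one of the saved checkpoints later, a minimal sketch for loading and running inference looks like this (assuming the LeNet class definition above is in scope; model11.pth is simply the checkpoint written after the last of the 12 epochs):

model=LeNet(10)
model.load_state_dict(torch.load('./convolution_checkpoint/model11.pth',map_location='cpu'))
model.eval()   # switch to evaluation mode before inference

images,labels=next(iter(validation_loader))
with torch.no_grad():
    _,preds=torch.max(model(images),1)
print('predicted:',preds[:10].tolist())
print('actual:   ',labels[:10].tolist())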
