Handwritten Digit Recognition with a Convolutional Neural Network
Import the required libraries
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
Set the hyperparameters
EPOCH = 1              # number of passes over the training set
BATCH_SIZE = 50        # mini-batch size
LR = 0.001             # learning rate for the Adam optimizer
DOWNLOAD_MNIST = True  # set to False once the dataset has been downloaded to ./mnist
Load the training data
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),  # convert images to [0, 1] FloatTensors
    download=DOWNLOAD_MNIST
)
Display a sample image
print(train_data.data.size())
print(train_data.targets.size())
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[0])
plt.show()
torch.Size([60000, 28, 28])
torch.Size([60000])
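To get a broader look at the data, a few more samples can be plotted side by side. This is a minimal sketch using the same train_data fields; the grid layout is illustrative and not part of the original script:
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for i, ax in enumerate(axes):
    ax.imshow(train_data.data[i].numpy(), cmap='gray')  # raw 28x28 pixel grid
    ax.set_title('%i' % train_data.targets[i])          # label shown as the title
    ax.axis('off')
plt.show()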
Set up the training loader and the test set
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
)
Fetch and preprocess the test data
with torch.no_grad():
    # Add a channel dimension, convert to float, keep the first 2000 test images,
    # and scale pixel values from [0, 255] down to [0, 1].
    test_x = Variable(torch.unsqueeze(test_data.data, dim=1)).type(torch.FloatTensor)[:2000] / 255
    test_y = test_data.targets[:2000]
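The resulting shapes can be verified with a quick check (a sketch; the sizes assume the standard 10,000-image MNIST test split):
print(test_x.shape)  # torch.Size([2000, 1, 28, 28]): 2000 single-channel 28x28 images
print(test_y.shape)  # torch.Size([2000])
print(test_x.min().item(), test_x.max().item())  # pixel values now lie in [0, 1]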
Define the convolutional network
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(           # input shape: (1, 28, 28)
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,                     # "same" padding keeps the 28x28 size
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),       # output shape: (16, 14, 14)
        )
        self.conv2 = nn.Sequential(            # input shape: (16, 14, 14)
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),       # output shape: (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # 10 output classes, one per digit

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)              # flatten to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output
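As a quick sanity check on these sizes (a minimal sketch; _check and _dummy are illustrative names), a dummy batch can be pushed through the two convolutional blocks to confirm that the flattened feature size is 32 * 7 * 7 = 1568:
_check = CNN()
_dummy = torch.zeros(1, 1, 28, 28)              # one fake grayscale 28x28 image
_feat = _check.conv2(_check.conv1(_dummy))
print(_feat.shape)                              # torch.Size([1, 32, 7, 7])
print(_feat.view(1, -1).shape)                  # torch.Size([1, 1568])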
The structure of the convolutional network
cnn = CNN()
print(cnn)
CNN(
(conv1): Sequential(
(0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv2): Sequential(
(0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(out): Linear(in_features=1568, out_features=10, bias=True)
)
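The in_features=1568 above is exactly the 32 * 7 * 7 from the shape trace earlier. If the total parameter count is of interest, it can be computed with a short snippet (a sketch; the per-layer figures are 16*1*5*5+16 = 416 for conv1, 32*16*5*5+32 = 12832 for conv2, and 1568*10+10 = 15690 for the linear layer):
n_params = sum(p.numel() for p in cnn.parameters() if p.requires_grad)
print(n_params)  # 416 + 12832 + 15690 = 28938 trainable parameters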
Training and prediction
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

# Start training the model
for epoch in range(EPOCH):
    for step, data in enumerate(train_loader):
        x, y = data
        b_x = Variable(x)
        b_y = Variable(y)
        output = cnn(b_x)               # forward pass
        loss = loss_fn(output, b_y)     # cross-entropy loss against the true labels
        optimizer.zero_grad()           # clear gradients from the previous step
        loss.backward()                 # back-propagate
        optimizer.step()                # update the weights
        if step % 50 == 0:              # evaluate on the held-out test samples every 50 steps
            test_output = cnn(test_x)
            y_pred = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = sum(y_pred == test_y).item() / test_y.size(0)
            print('now epoch : ', epoch, ' | loss : %.4f ' % loss.item(), ' | accuracy : ', accuracy)

test_output = cnn(test_x[:10])
y_pred = torch.max(test_output, 1)[1].data.squeeze()
print(y_pred.tolist(), 'prediction Result')
print(test_y[:10].tolist(), 'Real Result')
now epoch : 0 | loss : 2.3108 | accuracy : 0.18
now epoch : 0 | loss : 0.5315 | accuracy : 0.838
now epoch : 0 | loss : 0.2425 | accuracy : 0.891
now epoch : 0 | loss : 0.3464 | accuracy : 0.921
now epoch : 0 | loss : 0.2871 | accuracy : 0.9405
now epoch : 0 | loss : 0.1521 | accuracy : 0.937
now epoch : 0 | loss : 0.1424 | accuracy : 0.952
now epoch : 0 | loss : 0.0820 | accuracy : 0.955
now epoch : 0 | loss : 0.0958 | accuracy : 0.9585
now epoch : 0 | loss : 0.0694 | accuracy : 0.9615
now epoch : 0 | loss : 0.1700 | accuracy : 0.9645
now epoch : 0 | loss : 0.0583 | accuracy : 0.9595
now epoch : 0 | loss : 0.1821 | accuracy : 0.9645
now epoch : 0 | loss : 0.0640 | accuracy : 0.966
now epoch : 0 | loss : 0.1602 | accuracy : 0.9735
now epoch : 0 | loss : 0.0350 | accuracy : 0.972
now epoch : 0 | loss : 0.0863 | accuracy : 0.9585
now epoch : 0 | loss : 0.1025 | accuracy : 0.9655
now epoch : 0 | loss : 0.0898 | accuracy : 0.972
now epoch : 0 | loss : 0.0819 | accuracy : 0.976
now epoch : 0 | loss : 0.0593 | accuracy : 0.976
now epoch : 0 | loss : 0.0615 | accuracy : 0.9745
now epoch : 0 | loss : 0.0529 | accuracy : 0.978
now epoch : 0 | loss : 0.0372 | accuracy : 0.977
[7, 2, 1, 0, 4, 1, 4, 9, 5, 9] prediction Result
[7, 2, 1, 0, 4, 1, 4, 9, 5, 9] Real Result
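The accuracy reported during training is measured on only the first 2,000 test images. A batched evaluation over the full 10,000-image test set can be sketched as follows (assuming the same ./mnist data and ToTensor scaling; full_test, full_loader, correct, and total are illustrative names):
full_test = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),  # same [0, 1] scaling used for training
)
full_loader = Data.DataLoader(dataset=full_test, batch_size=BATCH_SIZE, shuffle=False)

correct = 0
total = 0
with torch.no_grad():                              # no gradients needed for evaluation
    for x, y in full_loader:
        pred = torch.max(cnn(x), 1)[1]
        correct += (pred == y).sum().item()
        total += y.size(0)
print('full test-set accuracy : %.4f' % (correct / total))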