卷积神经网络之手写数字识别
0.导入包
import torch
from torchvision import transforms, datasets
from torch. utils. data import DataLoader
import torch. nn. functional as F
import torch. optim as optim
1.准备数据集
# Number of samples per mini-batch, shared by the train and test loaders.
batch_size = 64

# ToTensor converts images to float tensors in [0, 1]; Normalize then
# standardizes with mean 0.1307 / std 0.3081 (presumably the MNIST
# training-set pixel statistics — TODO confirm source of these constants).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split, reshuffled every epoch.
train_dataset = datasets.MNIST(root='./data/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Held-out split used for evaluation. It must wrap test_dataset: wrapping
# train_dataset here silently reports training-set accuracy instead of
# generalization accuracy.
test_dataset = datasets.MNIST(root='./data/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
2.设计模型
class Net(torch.nn.Module):
    """Two-stage convolutional classifier that outputs 10 raw logits per image.

    Each stage applies a 5x5 convolution, a 2x2 max-pool and a ReLU. The
    flattened feature size of 320 (= 20 channels * 4 * 4) assumes
    single-channel 28x28 inputs — TODO confirm if other sizes are used.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so parameter registration
        # (and RNG consumption during initialization) stays stable.
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320, 10)

    def forward(self, x):
        """Map a batch of images to a (batch, 10) tensor of logits."""
        n = x.size(0)
        for conv in (self.conv1, self.conv2):
            x = F.relu(self.pooling(conv(x)))
        # Flatten everything except the batch dimension for the linear head.
        return self.fc(x.view(n, -1))


# Single module-level model instance shared by training and evaluation.
model = Net()
3.构造损失和优化器
# SGD hyper-parameters.
LEARNING_RATE = 0.01
MOMENTUM = 0.5

# Cross-entropy loss is applied to the raw logits returned by the model
# (its forward pass ends in a Linear layer with no softmax).
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
4.训练
def train(epoch, log_interval=300):
    """Run one training epoch over the module-level ``train_loader``.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``. Every
    ``log_interval`` batches it prints the mean loss over all batches seen so
    far in this epoch.

    Args:
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        log_interval: print progress every this many batches (default 300).

    Returns:
        Mean training loss over the epoch, or 0.0 if the loader yielded no batches.
    """
    # Make training mode explicit; it matters once dropout/batch-norm layers exist.
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches = batch_idx + 1
        if num_batches % log_interval == 0:
            print(f'epoch:{epoch + 1},batch_idx:{num_batches},loss:{running_loss / num_batches:.3f}')
    return running_loss / num_batches if num_batches else 0.0
5.测试
def test(loader=None):
    """Evaluate the module-level ``model`` and print its classification accuracy.

    Args:
        loader: iterable of ``(images, labels)`` batches; defaults to the
            module-level ``test_loader``.

    Returns:
        Accuracy in percent, or ``None`` if the loader yielded no samples.
    """
    if loader is None:
        loader = test_loader
    # Evaluate in eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                outputs = model(images)
                # Predicted class = index of the largest logit per sample.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print('正确率:N/A (no samples)')
        return None
    accuracy = 100 * correct / total
    print(f'正确率:{accuracy}%')
    return accuracy
6.执行
if __name__ == '__main__':
    # Alternate one full training epoch with one evaluation pass.
    NUM_EPOCHS = 10
    for ep in range(NUM_EPOCHS):
        train(ep)
        test()
epoch:1,batch_idx:300,loss:0.716
epoch:1,batch_idx:600,loss:0.887
epoch:1,batch_idx:900,loss:1.019
正确率:96.68166666666667%
epoch:2,batch_idx:300,loss:0.102
epoch:2,batch_idx:600,loss:0.191
epoch:2,batch_idx:900,loss:0.284
正确率:97.77333333333333%
epoch:3,batch_idx:300,loss:0.077
epoch:3,batch_idx:600,loss:0.151
epoch:3,batch_idx:900,loss:0.218
正确率:97.92833333333333%
epoch:4,batch_idx:300,loss:0.063
epoch:4,batch_idx:600,loss:0.120
epoch:4,batch_idx:900,loss:0.183
正确率:98.37166666666667%
epoch:5,batch_idx:300,loss:0.056
epoch:5,batch_idx:600,loss:0.109
epoch:5,batch_idx:900,loss:0.159
正确率:98.62833333333333%
epoch:6,batch_idx:300,loss:0.051
epoch:6,batch_idx:600,loss:0.097
epoch:6,batch_idx:900,loss:0.144
正确率:98.66666666666667%
epoch:7,batch_idx:300,loss:0.042
epoch:7,batch_idx:600,loss:0.086
epoch:7,batch_idx:900,loss:0.131
正确率:98.725%
epoch:8,batch_idx:300,loss:0.040
epoch:8,batch_idx:600,loss:0.079
epoch:8,batch_idx:900,loss:0.122
正确率:98.97166666666666%
epoch:9,batch_idx:300,loss:0.036
epoch:9,batch_idx:600,loss:0.077
epoch:9,batch_idx:900,loss:0.113
正确率:98.78833333333333%
epoch:10,batch_idx:300,loss:0.038
epoch:10,batch_idx:600,loss:0.070
epoch:10,batch_idx:900,loss:0.106
正确率:99.10166666666667%