# Imports
import torch
import torch. nn as nn
import torchvision
import matplotlib. pyplot as plt
import numpy as np
import torch. nn. functional as F
from torchinfo import summary
import warnings
# Dataset setup
# Build the MNIST training split first, then the test split. Both are stored
# under the 'data' directory (fetched if not already present) and both use a
# ToTensor transform so that samples come back as tensors.
train_ds, test_ds = (
    torchvision.datasets.MNIST(
        root='data',
        train=is_train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)
# Build the CNN
# Default number of output classes (MNIST has ten digit classes).
num_class = 10


class Model(nn.Module):
    """Small two-stage convolutional classifier for single-channel images.

    Layout::

        conv 3x3 (1 -> 32) -> ReLU -> max-pool 2
        conv 3x3 (32 -> 64) -> ReLU -> max-pool 2
        flatten -> linear (1600 -> 64) -> ReLU -> linear (64 -> classes)

    ``fc1`` expects 1600 input features, which is what a 28x28 input yields:
    28 -> 26 (conv) -> 13 (pool) -> 11 (conv) -> 5 (pool), and
    64 channels * 5 * 5 = 1600. Inputs of another spatial size fail at ``fc1``.

    Args:
        num_classes: Number of output logits. When omitted (``None``), the
            module-level ``num_class`` is looked up at construction time,
            which is exactly what the parameterless constructor used to do.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Fall back to the module-level default so ``Model()`` is unchanged.
            num_classes = num_class

        # Feature extractor: two conv/pool stages, no padding.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)

        # Classifier head: 64 * 5 * 5 = 1600 flattened features for 28x28 input.
        self.fc1 = nn.Linear(1600, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class scores for a batch of images."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Keep the batch dimension, collapse everything else.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        # No softmax here: the scores are returned as-is.
        x = self.fc2(x)
        return x
# Train the model
# Hyperparameters, loss function and optimizer
# Compute device read by train() and test() when they move each batch;
# use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the network on that device so the optimizer below has
# parameters to update and batches land on the same device as the weights.
model = Model().to(device)

# Cross-entropy is applied to the raw scores returned by Model.forward.
loss_fn = nn.CrossEntropyLoss()
learn_rate = 1e-2  # SGD step size
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)
# Training function
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Every batch is moved to the module-level ``device``, pushed through
    ``model``, scored with ``loss_fn`` and used for one ``optimizer`` step.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches that also exposes
            ``.dataset`` (used for the sample count).
        model: Callable producing per-class scores for a batch.
        loss_fn: Callable ``loss_fn(scores, labels)`` returning a scalar tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.

    Returns:
        ``(accuracy, loss)``: correct predictions divided by
        ``len(dataloader.dataset)``, and summed batch losses divided by
        ``len(dataloader)``.
    """
    sample_count = len(dataloader.dataset)
    batch_count = len(dataloader)
    correct_total = 0
    loss_total = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass and loss for this batch.
        scores = model(inputs)
        batch_loss = loss_fn(scores, labels)

        # Clear old gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Count hits using the scores computed before the update.
        hits = scores.argmax(1) == labels
        correct_total += hits.type(torch.float).sum().item()
        loss_total += batch_loss.item()

    return correct_total / sample_count, loss_total / batch_count
# Evaluation function
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` over ``dataloader`` with gradient tracking disabled.

    Every batch is moved to the module-level ``device`` before the forward
    pass. No optimizer step is taken and the model's train/eval mode is left
    as the caller set it.

    Args:
        dataloader: Iterable of ``(images, targets)`` batches that also
            exposes ``.dataset`` (used for the sample count).
        model: Callable producing per-class scores for a batch.
        loss_fn: Callable ``loss_fn(scores, targets)`` returning a scalar tensor.

    Returns:
        ``(test_acc, test_loss)``: correct predictions divided by
        ``len(dataloader.dataset)``, and summed batch losses divided by
        ``len(dataloader)``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    # Accumulate into test_acc: the previous code summed into a differently
    # named variable and then divided/returned test_acc, which was never
    # assigned and therefore raised UnboundLocalError.
    test_loss, test_acc = 0, 0

    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
            test_loss += loss.item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
# Visualize the results
# Suppress all warnings from here on so they do not clutter the output.
warnings.filterwarnings('ignore')

# Global matplotlib settings: SimHei as the sans-serif font, the Unicode
# minus glyph disabled, and a figure resolution of 100 dpi.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
    'figure.dpi': 100,
})

# NOTE(review): `epochs` and the train_acc / test_acc / train_loss / test_loss
# sequences are read below but not assigned in the visible part of the file -
# presumably one value per epoch from a training loop; confirm.
epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

# Left panel: accuracy curves.
acc_ax = plt.subplot(1, 2, 1)
acc_ax.plot(epochs_range, train_acc, label='Training Accuracy')
acc_ax.plot(epochs_range, test_acc, label='Test Accuracy')
acc_ax.legend(loc='lower right')
acc_ax.set_title('Training and Validation Accuracy')

# Right panel: loss curves.
loss_ax = plt.subplot(1, 2, 2)
loss_ax.plot(epochs_range, train_loss, label='Training Loss')
loss_ax.plot(epochs_range, test_loss, label='Test Loss')
loss_ax.legend(loc='upper right')
loss_ax.set_title('Training and Validation Loss')

plt.show()