1. Import the required libraries
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
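Before loading any data, it can also help to fix the random seeds so that repeated runs behave roughly the same. This is an optional addition, not part of the original code; the name RANDOM_SEED is just an example:

# Optional sketch: fix seeds for (mostly) reproducible runs
RANDOM_SEED = 1
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
print('Training on', DEVICE)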
2. Load the dataset
BATCH_SIZE = 64

train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False)
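As a quick sanity check (a minimal sketch, not part of the original pipeline), we can pull one batch from train_loader and confirm that MNIST images arrive as 1x28x28 tensors and that 60000 training images split into 938 batches of 64:

# Sketch: inspect one batch of training data
images, labels = next(iter(train_loader))
print(images.shape)   # expected: torch.Size([64, 1, 28, 28])
print(labels.shape)   # expected: torch.Size([64])
print(len(train_loader), 'training batches per epoch')  # 60000 / 64 -> 938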
3. Build the ResNet-18 model
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
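To see what a BasicBlock does to tensor shapes, here is a minimal sketch (illustrative only, not part of the training script) that pushes a random feature map through an identity block and through a strided block with a 1x1 downsample shortcut:

# Sketch: shape check for BasicBlock
block = BasicBlock(inplanes=64, planes=64)            # identity shortcut
down = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
                     nn.BatchNorm2d(128))
strided_block = BasicBlock(inplanes=64, planes=128, stride=2, downsample=down)

x = torch.randn(1, 64, 56, 56)
print(block(x).shape)           # torch.Size([1, 64, 56, 56])  -- shape preserved
print(strided_block(x).shape)   # torch.Size([1, 128, 28, 28]) -- halved resolution, more channels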
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes, grayscale):
        self.inplanes = 64
        # MNIST images are grayscale, so the stem takes a single input channel
        if grayscale:
            in_dim = 1
        else:
            in_dim = 3
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # For 28x28 MNIST inputs, layer4 already outputs a 1x1 feature map,
        # so the avgpool layer defined above is not applied in this forward pass
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas
def resnet18(num_classes):
    """Constructs a ResNet-18 model."""
    model = ResNet(block=BasicBlock,
                   layers=[2, 2, 2, 2],
                   num_classes=num_classes,
                   grayscale=True)
    return model


net = resnet18(10)
print(net)
Let's take a quick look at the printed network structure:
ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=512, out_features=10, bias=True)
)
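Since the final fully connected layer expects a 512-dimensional vector, a quick way to confirm that 28x28 MNIST inputs flow through the network correctly (and that layer4 indeed ends at 1x1, so the avgpool is skipped in forward) is a dummy forward pass. A minimal sketch, not part of the original post:

# Sketch: verify output shapes with a dummy MNIST-sized batch
dummy = torch.randn(2, 1, 28, 28)
logits, probas = net(dummy)
print(logits.shape)        # torch.Size([2, 10])
print(probas.sum(dim=1))   # each row of softmax probabilities sums to 1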
4. Train the model
NUM_EPOCHS = 10

model = resnet18(num_classes=10)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# the test set doubles as the validation set in this tutorial
valid_loader = test_loader


def compute_accuracy_and_loss(model, data_loader, device):
    correct_pred, num_examples = 0, 0
    cross_entropy = 0.
    for i, (features, targets) in enumerate(data_loader):
        features = features.to(device)
        targets = targets.to(device)
        logits, probas = model(features)
        # accumulate the summed (not batch-averaged) loss so that dividing by
        # num_examples below yields the average cross-entropy per example
        cross_entropy += F.cross_entropy(logits, targets, reduction='sum').item()
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    # return plain Python numbers so they can be printed and plotted directly
    return correct_pred.float().item() / num_examples * 100, cross_entropy / num_examples
start_time = time.time()
train_acc_lst, valid_acc_lst = [], []
train_loss_lst, valid_loss_lst = [], []

for epoch in range(NUM_EPOCHS):

    model.train()

    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        # forward and backward pass
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()

        # update model parameters
        optimizer.step()

        if not batch_idx % 300:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} | '
                  f'Cost: {cost:.4f}')

    model.eval()
    with torch.set_grad_enabled(False):  # save memory during evaluation
        train_acc, train_loss = compute_accuracy_and_loss(model, train_loader, device=DEVICE)
        valid_acc, valid_loss = compute_accuracy_and_loss(model, valid_loader, device=DEVICE)
        train_acc_lst.append(train_acc)
        valid_acc_lst.append(valid_acc)
        train_loss_lst.append(train_loss)
        valid_loss_lst.append(valid_loss)
        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'
              f' | Validation Acc.: {valid_acc:.2f}%')

    elapsed = (time.time() - start_time) / 60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time) / 60
print(f'Total Training Time: {elapsed:.2f} min')
Training output:
Epoch: 001/010 | Batch 000/938 | Cost: 2.5880
Epoch: 001/010 | Batch 300/938 | Cost: 0.2995
Epoch: 001/010 | Batch 600/938 | Cost: 0.0254
Epoch: 001/010 | Batch 900/938 | Cost: 0.0546
Epoch: 001/010 Train Acc.: 98.49% | Validation Acc.: 98.61%
Time elapsed: 0.50 min
Epoch: 002/010 | Batch 000/938 | Cost: 0.0118
Epoch: 002/010 | Batch 300/938 | Cost: 0.0618
Epoch: 002/010 | Batch 600/938 | Cost: 0.0279
Epoch: 002/010 | Batch 900/938 | Cost: 0.0147
Epoch: 002/010 Train Acc.: 98.89% | Validation Acc.: 98.82%
Time elapsed: 1.00 min
Epoch: 003/010 | Batch 000/938 | Cost: 0.0217
Epoch: 003/010 | Batch 300/938 | Cost: 0.0488
Epoch: 003/010 | Batch 600/938 | Cost: 0.0128
Epoch: 003/010 | Batch 900/938 | Cost: 0.0357
Epoch: 003/010 Train Acc.: 99.15% | Validation Acc.: 98.88%
Time elapsed: 1.51 min
Epoch: 004/010 | Batch 000/938 | Cost: 0.1168
Epoch: 004/010 | Batch 300/938 | Cost: 0.0824
Epoch: 004/010 | Batch 600/938 | Cost: 0.0131
Epoch: 004/010 | Batch 900/938 | Cost: 0.0212
Epoch: 004/010 Train Acc.: 99.22% | Validation Acc.: 99.04%
Time elapsed: 2.00 min
Epoch: 005/010 | Batch 000/938 | Cost: 0.0037
Epoch: 005/010 | Batch 300/938 | Cost: 0.0087
Epoch: 005/010 | Batch 600/938 | Cost: 0.0117
Epoch: 005/010 | Batch 900/938 | Cost: 0.0067
Epoch: 005/010 Train Acc.: 99.24% | Validation Acc.: 98.84%
Time elapsed: 2.50 min
Epoch: 006/010 | Batch 000/938 | Cost: 0.0020
Epoch: 006/010 | Batch 300/938 | Cost: 0.0038
Epoch: 006/010 | Batch 600/938 | Cost: 0.0264
Epoch: 006/010 | Batch 900/938 | Cost: 0.0038
Epoch: 006/010 Train Acc.: 99.38% | Validation Acc.: 99.06%
Time elapsed: 3.01 min
Epoch: 007/010 | Batch 000/938 | Cost: 0.0050
Epoch: 007/010 | Batch 300/938 | Cost: 0.0106
Epoch: 007/010 | Batch 600/938 | Cost: 0.0508
Epoch: 007/010 | Batch 900/938 | Cost: 0.0042
Epoch: 007/010 Train Acc.: 99.54% | Validation Acc.: 99.20%
Time elapsed: 3.50 min
Epoch: 008/010 | Batch 000/938 | Cost: 0.0098
Epoch: 008/010 | Batch 300/938 | Cost: 0.0012
Epoch: 008/010 | Batch 600/938 | Cost: 0.0073
Epoch: 008/010 | Batch 900/938 | Cost: 0.0016
Epoch: 008/010 Train Acc.: 99.47% | Validation Acc.: 99.19%
Time elapsed: 4.00 min
Epoch: 009/010 | Batch 000/938 | Cost: 0.0298
Epoch: 009/010 | Batch 300/938 | Cost: 0.0670
Epoch: 009/010 | Batch 600/938 | Cost: 0.0020
Epoch: 009/010 | Batch 900/938 | Cost: 0.0031
Epoch: 009/010 Train Acc.: 99.68% | Validation Acc.: 99.29%
Time elapsed: 4.50 min
Epoch: 010/010 | Batch 000/938 | Cost: 0.0005
Epoch: 010/010 | Batch 300/938 | Cost: 0.0556
Epoch: 010/010 | Batch 600/938 | Cost: 0.0012
Epoch: 010/010 | Batch 900/938 | Cost: 0.0044
Epoch: 010/010 Train Acc.: 99.44% | Validation Acc.: 99.11%
Time elapsed: 4.99 min
Total Training Time: 4.99 min
Plot of training loss vs. validation loss:
plt.plot(range(1, NUM_EPOCHS+1), train_loss_lst, label='Training loss')
plt.plot(range(1, NUM_EPOCHS+1), valid_loss_lst, label='Validation loss')
plt.legend(loc='upper right')
plt.ylabel('Cross entropy')
plt.xlabel('Epoch')
plt.show()
Plot of training accuracy vs. validation accuracy:
plt.plot(range(1, NUM_EPOCHS+1), train_acc_lst, label='Training accuracy')
plt.plot(range(1, NUM_EPOCHS+1), valid_acc_lst, label='Validation accuracy')
plt.legend(loc='upper left')
plt.ylabel('Accuracy (%)')
plt.xlabel('Epoch')
plt.show()
5. Test the model
model.eval()
with torch.set_grad_enabled(False):
    test_acc, test_loss = compute_accuracy_and_loss(model, test_loader, DEVICE)
    print(f'Test accuracy: {test_acc:.2f}%')

Test accuracy: 99.11%
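At this point it is often convenient to persist the trained weights so that later inference does not require retraining. A minimal sketch (the file name resnet18_mnist.pt is just an example, not part of the original post):

# Sketch: save the trained weights and reload them later (hypothetical file name)
torch.save(model.state_dict(), 'resnet18_mnist.pt')

restored = resnet18(num_classes=10).to(DEVICE)
restored.load_state_dict(torch.load('resnet18_mnist.pt', map_location=DEVICE))
restored.eval()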
6. Visualize some predictions
for features, targets in train_loader:
    break

_, predictions = model.forward(features[:8].to(DEVICE))
predictions = torch.argmax(predictions, dim=1)
print(predictions)

features = features[:7]
fig = plt.figure()

for i in range(6):
    plt.subplot(2, 3, i+1)
    plt.tight_layout()
    tmp = features[i]
    plt.imshow(np.transpose(tmp, (1, 2, 0)))
    plt.title("Actual value: {}".format(targets[i].item()) + '\n'
              + "Prediction value: {}".format(predictions[i].item()), size=10)
plt.show()
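To run the trained network on an image outside the MNIST loaders (for example a hand-written digit saved as a PNG), the image first has to be converted to the same 1x28x28 grayscale tensor format. A rough sketch, assuming a hypothetical file my_digit.png; note that MNIST digits are white strokes on a black background, so an external image may also need to be inverted:

# Sketch: classify a single external image (file name is hypothetical)
from PIL import Image

img = Image.open('my_digit.png').convert('L')           # force grayscale
img = transforms.Resize((28, 28))(img)                  # match MNIST resolution
x = transforms.ToTensor()(img).unsqueeze(0).to(DEVICE)  # shape: [1, 1, 28, 28]

model.eval()
with torch.no_grad():
    logits, probas = model(x)
    print('Predicted digit:', torch.argmax(probas, dim=1).item())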