1.定义数据预处理(transform)
import torchvision.transforms as transforms

# Spatial size every image is resized to.
image_size = (224, 224)


def _repeat_to_three_channels(x):
    """Tile a tensor three times along its first dimension.

    A named module-level function is used instead of a lambda so that the
    transform can be pickled when DataLoader worker processes are spawned
    (num_workers > 0 on platforms that use the ``spawn`` start method).
    """
    return x.repeat(3, 1, 1)


# Preprocessing applied to each sample; Compose runs the steps in list order.
# NOTE(review): the same pipeline (including the random flip) is reused for
# the validation split further down - confirm that this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(image_size),
    transforms.Lambda(_repeat_to_three_channels),
    transforms.Normalize((0.1307,), (0.3081,)),
])
2.导入数据集
import torchvision.datasets

# Training and validation splits of MNIST; both are stored under the same
# root, downloaded on demand, and share the `transform` pipeline.
mnist_train = torchvision.datasets.MNIST(
    root='~',
    train=True,
    download=True,
    transform=transform,
)
mnist_val = torchvision.datasets.MNIST(
    root='~',
    train=False,
    download=True,
    transform=transform,
)

# Print the number of classes of each split (training first).
for split in (mnist_train, mnist_val):
    print(len(split.classes))
3.制作DataLoader
from torch.utils.data import DataLoader

# Both loaders use the same batching options: 64 samples per batch,
# shuffled order, two worker processes.
loader_options = dict(batch_size=64, shuffle=True, num_workers=2)

trainloader = DataLoader(mnist_train, **loader_options)
valloader = DataLoader(mnist_val, **loader_options)
4.调用ResNet18 model
import torch
import torchvision.models as models

# ResNet-18 built with pretrained=True; its final `fc` layer is replaced by
# a fresh linear layer mapping 512 features to 10 outputs.
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(in_features=512, out_features=10)

# Show the resulting architecture.
print(model)
5.初始化全连接层(fc)权重(本节代码只对 fc 层做 Kaiming 初始化,并未创建优化器)
import torch
import torch.nn.init as init

# Re-initialize the weight of the direct child module named 'fc' with
# Kaiming-uniform initialization; every other child is left untouched.
for child_name, child in model.named_children():
    if child_name != 'fc':
        continue
    init.kaiming_uniform_(child.weight, a=0, mode='fan_in')
6.调用GPU
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
7.定义准确率函数
import torch
def accuracy(pred, target):
    """Count how many predictions match the targets.

    Args:
        pred: Score tensor; the predicted label of each row is the argmax
            along dimension 1.
        target: Tensor of ground-truth labels, compared element-wise with
            the predicted labels.

    Returns:
        A ``(correct, total)`` tuple. ``correct`` is a 0-dim float tensor
        holding the number of matching predictions; ``total`` is
        ``len(pred)``.
    """
    pred_label = torch.argmax(pred, 1)
    # Tensor.sum() reduces in a single call. The builtin sum() walked the
    # tensor element by element and, for an empty batch, returned the int 0,
    # on which the following ``.to`` call failed.
    correct = (pred_label == target).sum().to(torch.float)
    return correct, len(pred)
8.定义字典来存放数据
# Per-split history containers: `acc` collects accuracies and `loss_all`
# collects losses, each under the keys 'train' and 'val'.
acc = {split: [] for split in ('train', 'val')}
loss_all = {split: [] for split in ('train', 'val')}
9.开始训练和验证
"""设为训练模式"""
# One training pass followed by one validation pass over the MNIST loaders.
import torch.nn.functional as F

# The batches are moved to `device` below, so the model has to live there too.
model.to(device)

# NOTE(review): the original snippet used `optimizer` without ever creating
# it. SGD with momentum and these hyperparameters are an assumed default for
# fine-tuning - adjust as needed. Create it once, not once per epoch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# --- training mode ---
model.train()
train_correctnum, train_prednum, train_total_loss = 0., 0., 0.
# The loaders are created under the names `trainloader` / `valloader`.
for images, labels in trainloader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = F.cross_entropy(outputs, labels)
    optimizer.zero_grad()
    train_total_loss += loss.item()
    loss.backward()
    optimizer.step()
    correctnum, prednum = accuracy(outputs, labels)
    # float() keeps the running count a plain number instead of a (possibly
    # CUDA) tensor, so the stored accuracies can be plotted later.
    train_correctnum += float(correctnum)
    train_prednum += prednum

# --- evaluation mode ---
model.eval()
valid_correctnum, valid_prednum, valid_total_loss = 0., 0., 0.
# No gradients are needed while validating.
with torch.no_grad():
    for images, labels in valloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        valid_total_loss += loss.item()
        correctnum, prednum = accuracy(outputs, labels)
        valid_correctnum += float(correctnum)
        valid_prednum += prednum

# Mean loss per batch.
train_loss = train_total_loss / len(trainloader)
valid_loss = valid_total_loss / len(valloader)

# Record the losses.
loss_all['train'].append(train_loss)
loss_all['val'].append(valid_loss)

# Record the accuracies.
acc['train'].append(train_correctnum / train_prednum)
acc['val'].append(valid_correctnum / valid_prednum)

print('train_loss:{:.6f} \t valid_loss:{:.6f}'.format(train_loss, valid_loss))
print('train_acc:{:.6f} \t valid_acc:{:.6f}'.format(train_correctnum / train_prednum, valid_correctnum / valid_prednum))
10.训练结果
11.绘制loss 和acc曲线
import matplotlib.pyplot as plt


def _plot_history(history, title, ylabel, ylim):
    """Draw the train (orange) and validation (blue) curves of one metric.

    Args:
        history: Dict with 'train' and 'val' lists of per-epoch values.
        title: Figure title.
        ylabel: Label of the y axis.
        ylim: ``(low, high)`` limits of the y axis.
    """
    plt.ylim(ylim)
    plt.xlim((0, 10))
    # float() accepts plain numbers and 0-dim tensors alike, including CUDA
    # tensors, which matplotlib cannot convert on its own.
    plt.plot([float(value) for value in history['train']], color='orange')
    plt.plot([float(value) for value in history['val']], color='blue')
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    # Each figure is shown explicitly; the accuracy figure previously had no
    # show() call.
    plt.show()


_plot_history(loss_all, 'loss function', 'loss', (0, 0.6))
_plot_history(acc, 'accuracy rate', 'accuracy', (0.8, 1))
12.完整代码
"""""" """""" """""" "数据初始化" """""" """""" """""" """""" ""
import torchvision.transforms as transforms

# Spatial size every image is resized to.
image_size = (224, 224)


def _repeat_to_three_channels(x):
    """Tile a tensor three times along its first dimension.

    A named module-level function is used instead of a lambda so that the
    transform can be pickled when DataLoader worker processes are spawned
    (num_workers > 0 on platforms that use the ``spawn`` start method).
    """
    return x.repeat(3, 1, 1)


# Preprocessing applied to each sample; Compose runs the steps in list order.
# NOTE(review): the same pipeline (including the random flip) is reused for
# the validation split further down - confirm that this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(image_size),
    transforms.Lambda(_repeat_to_three_channels),
    transforms.Normalize((0.1307,), (0.3081,)),
])
"""""" """""" """""" "导入数据集" """""" """""" """"""
import torchvision.datasets

# Training and validation splits of MNIST; both are stored under the same
# root, downloaded on demand, and share the `transform` pipeline.
mnist_train = torchvision.datasets.MNIST(
    root='~',
    train=True,
    download=True,
    transform=transform,
)
mnist_val = torchvision.datasets.MNIST(
    root='~',
    train=False,
    download=True,
    transform=transform,
)

# Print the number of classes of each split (training first).
for split in (mnist_train, mnist_val):
    print(len(split.classes))
"""""" """""" """制作DataLoader""" """""" """""" """"""
from torch.utils.data import DataLoader

# Both loaders use the same batching options: 64 samples per batch,
# shuffled order, two worker processes.
loader_options = dict(batch_size=64, shuffle=True, num_workers=2)

trainloader = DataLoader(mnist_train, **loader_options)
valloader = DataLoader(mnist_val, **loader_options)
"""""" """""" """""" "调用model" """""" """""" """"""
import torch
import torchvision.models as models

# ResNet-18 built with pretrained=True; its final `fc` layer is replaced by
# a fresh linear layer mapping 512 features to 10 outputs.
model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(in_features=512, out_features=10)

# Show the resulting architecture.
print(model)
"""""" """""" """""" """"调用GPU""" """""" """""" """""" """""" ""
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
"""""" """""" """计算准确率""" """""" """"""
import torch
def accuracy(pred, target):
    """Count how many predictions match the targets.

    Args:
        pred: Score tensor; the predicted label of each row is the argmax
            along dimension 1.
        target: Tensor of ground-truth labels, compared element-wise with
            the predicted labels.

    Returns:
        A ``(correct, total)`` tuple. ``correct`` is a 0-dim float tensor
        holding the number of matching predictions; ``total`` is
        ``len(pred)``.
    """
    pred_label = torch.argmax(pred, 1)
    # Tensor.sum() reduces in a single call. The builtin sum() walked the
    # tensor element by element and, for an empty batch, returned the int 0,
    # on which the following ``.to`` call failed.
    correct = (pred_label == target).sum().to(torch.float)
    return correct, len(pred)
# Per-split history containers: `acc` collects accuracies and `loss_all`
# collects losses, each under the keys 'train' and 'val'.
acc = {split: [] for split in ('train', 'val')}
loss_all = {split: [] for split in ('train', 'val')}
"""""" """""" """""验证和训练""" """""" """"""
# Train for 10 epochs, validating after each one and recording the metrics.
import torch.nn.functional as F

# The batches are moved to `device` inside the loops, so the model goes first.
model.to(device)

# NOTE(review): the original script used `optimizer` without ever creating
# it. SGD with momentum and these hyperparameters are an assumed default for
# fine-tuning - adjust as needed.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    print("epoch", epoch + 1, ":***************************")

    # --- training mode ---
    model.train()
    train_correctnum, train_prednum, train_total_loss = 0., 0., 0.
    # The loaders are created under the names `trainloader` / `valloader`.
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        optimizer.zero_grad()
        train_total_loss += loss.item()
        loss.backward()
        optimizer.step()
        correctnum, prednum = accuracy(outputs, labels)
        # float() keeps the running count a plain number instead of a
        # (possibly CUDA) tensor, so the stored accuracies can be plotted.
        train_correctnum += float(correctnum)
        train_prednum += prednum

    # --- evaluation mode ---
    model.eval()
    valid_correctnum, valid_prednum, valid_total_loss = 0., 0., 0.
    # No gradients are needed while validating.
    with torch.no_grad():
        for images, labels in valloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = F.cross_entropy(outputs, labels)
            valid_total_loss += loss.item()
            correctnum, prednum = accuracy(outputs, labels)
            valid_correctnum += float(correctnum)
            valid_prednum += prednum

    # Mean loss per batch.
    train_loss = train_total_loss / len(trainloader)
    valid_loss = valid_total_loss / len(valloader)

    # Record the losses.
    loss_all['train'].append(train_loss)
    loss_all['val'].append(valid_loss)

    # Record the accuracies.
    acc['train'].append(train_correctnum / train_prednum)
    acc['val'].append(valid_correctnum / valid_prednum)

    print('train_loss:{:.6f} \t valid_loss:{:.6f}'.format(train_loss, valid_loss))
    print('train_acc:{:.6f} \t valid_acc:{:.6f}'.format(train_correctnum / train_prednum, valid_correctnum / valid_prednum))
"""""" """""" """""绘图""" """""" """""" """"""
import matplotlib.pyplot as plt


def _plot_history(history, title, ylabel, ylim):
    """Draw the train (orange) and validation (blue) curves of one metric.

    Args:
        history: Dict with 'train' and 'val' lists of per-epoch values.
        title: Figure title.
        ylabel: Label of the y axis.
        ylim: ``(low, high)`` limits of the y axis.
    """
    plt.ylim(ylim)
    plt.xlim((0, 10))
    # float() accepts plain numbers and 0-dim tensors alike, including CUDA
    # tensors, which matplotlib cannot convert on its own.
    plt.plot([float(value) for value in history['train']], color='orange')
    plt.plot([float(value) for value in history['val']], color='blue')
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    # Each figure is shown explicitly; the accuracy figure previously had no
    # show() call.
    plt.show()


_plot_history(loss_all, 'loss function', 'loss', (0, 0.6))
_plot_history(acc, 'accuracy rate', 'accuracy', (0.8, 1))