PyTorch学习——断点续训
1、保存断点模型状态
# Save a resumable checkpoint (epoch counter + model/optimizer state) at the
# end of every epoch, so an interrupted run can pick up where it stopped.
for epoch in range(starting_epoch, parse.epochs):
    for step, (x, y) in enumerate(train_loader):
        ...
    if epoch % 1 == 0:  # save every epoch; raise the modulus to save less often
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        model_path = "./model/" + "model_{:03d}.pt".format(epoch)
        torch.save(checkpoint, model_path)
2、读取状态
# Restore training state from a saved checkpoint and resume at the next epoch.
checkpoint_path = "./model/model_" + parse.checkpoint + ".pt"
train_state = torch.load(checkpoint_path)
model.load_state_dict(train_state['model_state_dict'])
optimizer.load_state_dict(train_state['optimizer_state_dict'])
starting_epoch = train_state['epoch'] + 1  # the saved epoch finished; resume after it
3、完整的例程
import torch
from torch import optim, nn
import torchvision
from torch. utils. data import DataLoader
from defect import DefectDataset
from resnet import ResNet18
import matplotlib. pyplot as plt
import argparse
def parse_args():
    """Parse command-line arguments for the defect-detection training script.

    Returns:
        argparse.Namespace with:
            echo (str): positional string, echoed back at startup.
            epochs (int): number of training epochs, default 10.
            resume (bool): True when --resume is given.
            checkpoint (str): checkpoint epoch number to resume from, default "0".
    """
    parser = argparse.ArgumentParser(description="defect detect")
    parser.add_argument("echo", help="echo the string you use here")
    parser.add_argument("--epochs", default=10, type=int,
                        help="number of training epochs")
    # Flag only: its presence triggers loading the checkpoint named by --checkpoint.
    parser.add_argument("--resume", action="store_true",
                        help="resume training from a saved checkpoint")
    parser.add_argument("--checkpoint", default="0",
                        help="checkpoint epoch number to resume from")
    return parser.parse_args()
def evalute(model, loader, device):
    """Compute classification accuracy of ``model`` over ``loader``.

    Args:
        model: module whose forward returns per-class logits of shape
            (batch, num_classes).
        loader: DataLoader yielding (x, y) batches; y holds class indices.
        device: device to run inference on.

    Returns:
        float: accuracy in [0, 1]; 0.0 for an empty dataset (the original
        raised ZeroDivisionError in that case).

    Note: leaves the model in eval mode; callers that keep training must
    switch back with ``model.train()``.
    """
    model.eval()
    total = len(loader.dataset)
    if total == 0:  # guard: empty dataset would divide by zero below
        return 0.0
    correct = 0
    # Hoisted: one no-grad context for the whole pass instead of per batch.
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total
def train(parse, device, train_loader, val_loader, train_state=None, lr=1e-3):
    """Train ResNet18 on the defect dataset with checkpoint-based resume.

    Args:
        parse: parsed CLI args (uses .epochs, .resume, .checkpoint).
        device: device to train on.
        train_loader: DataLoader of training batches (x, y).
        val_loader: DataLoader of validation batches.
        train_state: optional checkpoint dict to resume from; superseded by
            the file named by --checkpoint when --resume is given.
        lr: Adam learning rate. New keyword with the same value the code
            previously read from a module-level global, so callers are
            unaffected but the hidden global dependency is gone.
    """
    import os  # local import: only needed to ensure the checkpoint dir exists

    model = ResNet18().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0
    train_loss_list = []
    acc_val_list = []
    starting_epoch = 0
    if parse.resume:
        # Fixed: use the same zero-padded name the save path produces
        # ("model_000.pt", ...); "model_" + parse.checkpoint + ".pt" with the
        # default "0" looked for a "model_0.pt" that is never written.
        checkpoint_path = "./model/model_{:03d}.pt".format(int(parse.checkpoint))
        train_state = torch.load(checkpoint_path)
    if train_state is not None:
        model.load_state_dict(train_state['model_state_dict'])
        optimizer.load_state_dict(train_state['optimizer_state_dict'])
        starting_epoch = train_state['epoch'] + 1  # saved epoch finished; resume after it
    os.makedirs("./model", exist_ok=True)  # torch.save fails if the dir is missing
    for epoch in range(starting_epoch, parse.epochs):
        # Hoisted from the step loop: needed once per epoch because evalute()
        # flips the model to eval mode at the end of the previous epoch.
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            train_loss_list.append(loss.item())
            print('epoch: ', epoch, 'step:', step, 'loss: ', loss.item())
        if epoch % 1 == 0:  # validate and checkpoint every epoch
            val_acc = evalute(model, val_loader, device)
            acc_val_list.append(val_acc)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            model_path = "./model/" + "model_{:03d}.pt".format(epoch)
            torch.save(checkpoint, model_path)
    print('best acc:', best_acc, 'best epoch:', best_epoch)
if __name__ == '__main__':
    parse = parse_args()
    print(parse.echo)
    epochs = parse.epochs
    batchsz = 8
    lr = 1e-3  # fixed: "1e - 3" is a tokenizer error, not a valid float literal
    # Fall back to CPU so the script still runs on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(1234)  # reproducible weight init / shuffling
    train_db = DefectDataset('data', 416, mode='train')
    val_db = DefectDataset('data', 416, mode='val')
    test_db = DefectDataset('data', 416, mode='test')
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=8)
    # Fixed: these two loaders previously wrapped train_db, so validation and
    # testing silently measured accuracy on the training set.
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=8)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=8)
    train(parse, device, train_loader, val_loader)