Checkpoint (PyTorch)
Deep learning models need to save their parameters as training progresses; a checkpoint is a snapshot of the model's parameters saved after a training epoch. It works like saving your progress in a video game: you can resume at any time by loading the save file.
Training a deep learning model usually takes a long time. To avoid losing progress, it is recommended to checkpoint the model's parameters at every epoch, and to additionally keep a copy of the parameters whenever they are the best seen so far.
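Concretely, a checkpoint is just a Python dict serialized with torch.save. Here is a minimal sketch (the tiny linear model and the file name are only for illustration):

import torch
import torch.nn as nn

# Any nn.Module works the same way; a tiny model keeps the example short.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle everything needed to resume training into one dict and save it.
state = {'epoch': 5,
         'state_dict': model.state_dict(),
         'optim_dict': optimizer.state_dict()}
torch.save(state, 'last.pth.tar')

# Later, "load the save file" to restore the snapshot.
state = torch.load('last.pth.tar')
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optim_dict'])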
Creating checkpoints with PyTorch
Saving a checkpoint
import os
import shutil

import torch


def save_checkpoint(state, is_best, checkpoint):
    """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
    checkpoint + 'best.pth.tar'

    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        os.mkdir(checkpoint)
    else:
        print("Checkpoint Directory exists!")
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
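A hypothetical call, assuming model, optimizer and epoch already exist in the training script (the directory name is illustrative):

save_checkpoint({'epoch': epoch + 1,
                 'state_dict': model.state_dict(),
                 'optim_dict': optimizer.state_dict()},
                is_best=True,
                checkpoint='experiments/base_model')

This writes experiments/base_model/last.pth.tar and, because is_best=True, also copies it to best.pth.tar.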
Loading a checkpoint
def load_checkpoint(checkpoint, model, optimizer=None):
    """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of
    optimizer assuming it is present in checkpoint.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
    """
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer:
        optimizer.load_state_dict(checkpoint['optim_dict'])
    return checkpoint
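Note that the function returns the full checkpoint dict, so any extra keys stay accessible. A sketch of resuming training (the path is illustrative):

ckpt = load_checkpoint('experiments/base_model/last.pth.tar', model, optimizer)
start_epoch = ckpt['epoch']  # continue the epoch count from the snapshot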
When training and evaluating the model, the following needs to be done:
- Optionally restore model and optimizer state from a checkpoint
if restore_file is not None:
    restore_path = os.path.join(args.model_dir, args.restore_file + '.pth.tar')
    logging.info("Restoring parameters from {}".format(restore_path))
    utils.load_checkpoint(restore_path, model, optimizer)
- Train and evaluate once per epoch, tracking the best result so far
- Save the latest weights after every epoch, and also save them as the best weights whenever the epoch produced the best result
best_val_acc = 0.0  # track best validation accuracy (or loss) outside of loop

for epoch in range(params.num_epochs):
    # Run one epoch
    logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

    # compute number of batches in one epoch (one full pass over the training set)
    train(model, optimizer, loss_fn, train_dataloader, metrics, params)

    # Evaluate for one epoch on validation set
    val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)

    # Get validation accuracy and track best validation accuracy
    val_acc = val_metrics['accuracy']
    is_best = val_acc >= best_val_acc

    # Save weights if validation accuracy is best
    utils.save_checkpoint({'epoch': epoch + 1,
                           'state_dict': model.state_dict(),
                           'optim_dict': optimizer.state_dict()},
                          is_best=is_best,
                          checkpoint=model_dir)

    # If best_eval, best_save_path
    if is_best:
        logging.info("- Found new best accuracy")
        best_val_acc = val_acc

        # Save best val metrics in a json file in the model directory
        best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
        utils.save_dict_to_json(val_metrics, best_json_path)
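For the final test evaluation, the best snapshot can be restored with the same helper. A sketch, assuming a test_dataloader is defined like the other dataloaders above:

# Reload the weights that achieved the best validation accuracy ...
utils.load_checkpoint(os.path.join(model_dir, 'best.pth.tar'), model)
# ... and evaluate them on the held-out test set
test_metrics = evaluate(model, loss_fn, test_dataloader, metrics, params)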