Introduction
Serialization and deserialization in PyTorch
- torch.save
Main parameters:
- obj: the object to save
- f: the output path
- torch.load
Main parameters:
- f: the file path
- map_location: where to place the loaded tensors, CPU or GPU
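A minimal round-trip sketch of the two functions above; the file name `tensor.pt` is illustrative, not from the notes:

```python
import torch

# Save a tensor to disk, then load it back.
t = torch.arange(4, dtype=torch.float32)
torch.save(t, "tensor.pt")

# map_location="cpu" forces tensors that were saved on a GPU to be
# loaded onto the CPU, which matters when a GPU-trained model is
# loaded on a CPU-only machine.
t_loaded = torch.load("tensor.pt", map_location="cpu")
print(torch.equal(t, t_loaded))  # True
```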
There are two ways to save:
1. Save the entire module: torch.save(net, path)
2. Save only the model's parameters:
state_dict = net.state_dict()
torch.save(state_dict, path)
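The two methods can be sketched side by side; the tiny `nn.Linear` network and the file names are illustrative assumptions:

```python
import torch
import torch.nn as nn

net = nn.Linear(3, 2)  # a tiny stand-in network

# Method 1: pickle the entire module. Loading it later requires the
# original class definition to be importable.
torch.save(net, "model.pkl")

# Method 2: save only the parameters. This is smaller and decoupled
# from the code that defines the class.
state_dict = net.state_dict()
torch.save(state_dict, "model_state_dict.pkl")

print(sorted(state_dict.keys()))  # ['bias', 'weight']
```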
# Method 1: load the entire model
path_model = './model.pkl'
net_load = torch.load(path_model)
# Method 2: load the saved parameters into an existing model
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
net.load_state_dict(state_dict_load)
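Note that method 2 requires an instance of the same architecture to already exist: load_state_dict copies the saved parameters into that instance in place. A self-contained sketch, with an assumed `nn.Linear` architecture:

```python
import torch
import torch.nn as nn

# Pretend "saved" is the trained model whose parameters were written out.
saved = nn.Linear(3, 2)
torch.save(saved.state_dict(), "model_state_dict.pkl")

# A fresh instance of the SAME architecture must be built first.
net = nn.Linear(3, 2)
state_dict_load = torch.load("model_state_dict.pkl")
net.load_state_dict(state_dict_load)  # copies parameters in place

print(torch.equal(net.weight, saved.weight))  # True
```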
Resuming training from a checkpoint
Which information needs to be saved?
Only the model's and the optimizer's parameters need to be saved; in addition, the current epoch must be recorded.
checkpoint_interval = 5
# ... training code omitted ...
# save a checkpoint every checkpoint_interval epochs
if (epoch + 1) % checkpoint_interval == 0:
    checkpoint = {"model_state_dict": net.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "epoch": epoch}
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
    torch.save(checkpoint, path_checkpoint)
# load the checkpoint
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
scheduler.last_epoch = start_epoch  # keep the LR scheduler in sync with the resumed epoch
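Putting save and restore together, a minimal end-to-end resume sketch; the `nn.Linear` model, the SGD/StepLR setup, and the epoch count of 10 are illustrative assumptions, not from the notes:

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(1, 1)
optimizer = optim.SGD(net.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save a checkpoint after (say) epoch 4.
checkpoint = {"model_state_dict": net.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": 4}
torch.save(checkpoint, "./checkpoint_4_epoch.pkl")

# --- later, after a restart ---
checkpoint = torch.load("./checkpoint_4_epoch.pkl")
net.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
scheduler.last_epoch = start_epoch

# Resume from the epoch AFTER the saved one, so a 10-epoch run
# continues with epochs 5..9 instead of repeating epoch 4.
resumed_epochs = list(range(start_epoch + 1, 10))
print(resumed_epochs)  # [5, 6, 7, 8, 9]
```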