PyTorch: Saving and Loading Models
This note is translated from the official PyTorch tutorial and kept here for easy reference.
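PyTorch has three commonly used functions for saving and loading models:
- torch.save(): saves a serialized object to disk. Serialization is implemented with Python's pickle module.
- torch.load(): deserializes a pickled object file and loads it into memory.
- torch.nn.Module.load_state_dict(): loads a model's parameters from a deserialized state_dict object.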
1. state_dict
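In PyTorch, the learnable parameters of a model (weights and biases) are contained in model.parameters(). A state_dict is simply a Python dictionary that maps each layer to its parameter tensors; only layers with learnable parameters and registered buffers (such as batchnorm's running_mean) have entries. torch.optim optimizer objects also have a state_dict, which contains the optimizer's state as well as the hyperparameters used.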
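To see what a state_dict actually holds, the sketch below builds a small hypothetical model (TinyNet is an illustrative name, not from the tutorial) and prints its entries. Note that the parameter-free ReLU layer contributes nothing, while BatchNorm1d contributes both parameters and buffers.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical model, used only to inspect the contents of state_dict().
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.bn = nn.BatchNorm1d(2)
        self.act = nn.ReLU()  # no parameters or buffers, so no state_dict entries

    def forward(self, x):
        return self.act(self.bn(self.fc(x)))

model = TinyNet()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Model state_dict: parameter/buffer names mapped to tensors
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.size()))

# Optimizer state_dict: a "state" entry plus "param_groups" holding hyperparameters
print(optimizer.state_dict()["param_groups"][0]["lr"])  # → 0.01
```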
2. Saving & loading a model for inference (recommended)
save
torch.save(model.state_dict(), PATH)
load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # don't forget this when running inference
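Why model.eval() matters can be shown with a single batch normalization layer, a minimal sketch assuming random input data. In training mode a forward pass alone updates the running statistics (buffers, not the learnable weights), even inside torch.no_grad() and without any optimizer step; in evaluation mode they are left untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 5 + 2  # batch whose mean differs from the initial running mean of 0

# Training mode (the default): forward passes update running_mean/running_var.
bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# Evaluation mode: the stored running statistics are used as-is.
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train, changed_in_eval)  # → True False
```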
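- The saved file conventionally uses a .pt or .pth extension.
- Remember to call model.eval() before running inference. It switches dropout and batch normalization layers to evaluation mode. Otherwise dropout stays active, and batch normalization keeps normalizing with per-batch statistics and updating its running mean and variance on every forward pass, even though no training is taking place, which makes inference results inconsistent.
3. Saving & loading the entire model (not recommended)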
save
torch.save(model, PATH)
load
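# Model class must be defined somewhere
model = torch.load(PATH, weights_only=False)  # weights_only=False is required on PyTorch 2.6+ to unpickle a whole model
model.eval()
- This approach pickles the entire module. pickle does not store the model class itself, only a reference to it, so the saved file is tied to the exact classes and directory structure used when saving, and loading can break after refactoring or in another project.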
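A sketch of saving and loading a whole model follows, assuming a small stand-in for TheModelClass and an in-memory buffer instead of PATH. The class has to be defined (or importable) at load time, because only a reference to it is pickled; weights_only=False assumes PyTorch 1.13 or later, where that argument exists.

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in for TheModelClass; it must be resolvable when unpickling.
class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)

model = TheModelClass()
buffer = io.BytesIO()      # stands in for PATH
torch.save(model, buffer)  # pickles the whole module object

buffer.seek(0)
# A full nn.Module is not a plain tensor container, so opt out of weights-only loading.
loaded = torch.load(buffer, weights_only=False)
loaded.eval()

print(type(loaded).__name__)  # → TheModelClass
```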
4. Saving & loading a model with a checkpoint for inference or resuming training
save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)
load
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# or
model.train()
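The checkpoint pattern above can be exercised end to end as in this sketch. It assumes nn.Linear as a stand-in for TheModelClass, SGD with momentum so the optimizer has per-parameter state worth saving, an in-memory buffer instead of PATH, and a made-up epoch number.

```python
import io
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(2, 1)  # stand-in for TheModelClass
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds momentum buffers
loss = nn.functional.mse_loss(model(torch.randn(4, 2)), torch.randn(4, 1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

buffer = io.BytesIO()  # stands in for PATH
torch.save({
    "epoch": 5,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),  # a plain float rather than a tensor attached to the graph
}, buffer)

# Resume: construct fresh objects first, then restore their state
new_model = nn.Linear(2, 1)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
buffer.seek(0)
checkpoint = torch.load(buffer)
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
new_model.train()  # continue training; call new_model.eval() for inference instead

print(start_epoch)  # → 6
```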
5. Saving multiple models in one file
save
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    ...
}, PATH)
load
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# or
modelA.train()
modelB.train()
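The multi-model pattern can be sketched with a GAN-style pair, assuming two nn.Linear layers as hypothetical generator and discriminator, Adam optimizers, and an in-memory buffer in place of PATH. Each object is rebuilt first and then restored from its own key in the single file.

```python
import io
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical generator/discriminator pair and their optimizers
netG, netD = nn.Linear(4, 8), nn.Linear(8, 1)
optG = optim.Adam(netG.parameters(), lr=2e-4)
optD = optim.Adam(netD.parameters(), lr=2e-4)

buffer = io.BytesIO()  # stands in for PATH, e.g. a ".tar" checkpoint file
torch.save({
    "netG_state_dict": netG.state_dict(),
    "netD_state_dict": netD.state_dict(),
    "optG_state_dict": optG.state_dict(),
    "optD_state_dict": optD.state_dict(),
}, buffer)

# Rebuild every object, then restore each one from its own key
newG, newD = nn.Linear(4, 8), nn.Linear(8, 1)
newOptG = optim.Adam(newG.parameters(), lr=2e-4)
newOptD = optim.Adam(newD.parameters(), lr=2e-4)
buffer.seek(0)
checkpoint = torch.load(buffer)
newG.load_state_dict(checkpoint["netG_state_dict"])
newD.load_state_dict(checkpoint["netD_state_dict"])
newOptG.load_state_dict(checkpoint["optG_state_dict"])
newOptD.load_state_dict(checkpoint["optD_state_dict"])
newG.eval()
newD.eval()

print(sorted(checkpoint.keys()))
# → ['netD_state_dict', 'netG_state_dict', 'optD_state_dict', 'optG_state_dict']
```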
- This pattern is typical for GANs, sequence-to-sequence models, or ensemble models.
- Checkpoints like these are commonly saved with the .tar file extension.
6. Warmstarting Model Using Parameters From A Different Model
save
torch.save(modelA.state_dict(), PATH)
load
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
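- In transfer learning you may want to load only part of a model's parameters. Setting strict=False ignores keys that do not match (missing or unexpected keys). Keys that do match but hold tensors of a different shape still raise an error.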
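What strict=False actually ignores can be seen in this sketch, assuming two hypothetical models (ModelA, ModelB) that share only a "backbone" layer. load_state_dict() copies the matching keys and returns the non-matching ones in missing_keys and unexpected_keys instead of raising.

```python
import torch
import torch.nn as nn

# Hypothetical source and target models sharing only the "backbone" layer
class ModelA(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head_a = nn.Linear(4, 2)

class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head_b = nn.Linear(4, 3)

modelA, modelB = ModelA(), ModelB()

# Copies backbone.* from A into B and reports the rest
result = modelB.load_state_dict(modelA.state_dict(), strict=False)
print(result.missing_keys)     # keys B expects but A's state_dict lacks (head_b.*)
print(result.unexpected_keys)  # keys in A's state_dict that B does not have (head_a.*)
```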
7. Saving & loading models across devices
(1) Save on GPU, Load on CPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
(2) Save on GPU, Load on GPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
(3) Save on CPU, Load on GPU
save
torch.save(model.state_dict(), PATH)
load
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
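The three cases above can be folded into one device-agnostic sketch, assuming nn.Linear as a stand-in for TheModelClass and an in-memory buffer instead of PATH. map_location remaps every saved tensor onto the device chosen at load time, and model.to(device) moves the module itself; inputs also need .to(device).

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # stand-in for TheModelClass
buffer = io.BytesIO()    # stands in for PATH
torch.save(model.state_dict(), buffer)

# Choose the target device when loading, whatever device the tensors were saved from
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buffer.seek(0)
state_dict = torch.load(buffer, map_location=device)

new_model = TheModel = nn.Linear(3, 2)
new_model.load_state_dict(state_dict)
new_model.to(device)  # inputs must be moved too, e.g. x.to(device)

print(next(new_model.parameters()).device.type)
```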
8. Saving a torch.nn.DataParallel model
save
torch.save(model.module.state_dict(), PATH)
load
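# Load to whatever device you want
- model.module.state_dict() stores the wrapped model's parameters without the "module." key prefix that DataParallel adds, so the saved file can be loaded into a plain, unwrapped model on any device with the usual load_state_dict() pattern.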
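The "module." prefix can be observed directly in this sketch, assuming nn.Linear as a stand-in for TheModelClass and an in-memory buffer instead of PATH (nn.DataParallel can be constructed on a CPU-only machine, where it simply wraps the module). The last lines show one hand-written way to strip the prefix from a checkpoint that was saved from the wrapper itself.

```python
import io
import torch
import torch.nn as nn

net = nn.Linear(3, 2)           # stand-in for TheModelClass
wrapped = nn.DataParallel(net)  # registers net as the submodule "module"

print(list(wrapped.state_dict().keys()))         # → ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict().keys()))  # → ['weight', 'bias']

# Save the inner module's state_dict, then load it into an unwrapped model
buffer = io.BytesIO()  # stands in for PATH
torch.save(wrapped.module.state_dict(), buffer)
buffer.seek(0)
plain = nn.Linear(3, 2)
plain.load_state_dict(torch.load(buffer, map_location="cpu"))

# If a checkpoint came from wrapped.state_dict(), strip the prefix by hand
prefix = "module."
stripped = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in wrapped.state_dict().items()
}
plain.load_state_dict(stripped)
```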