[模块]pytorch模型的储存和载入

pytorch保存和载入模型

1.相关函数

  • torch.save

    torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)
    
  • torch.load

    torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
    

    map_location 选择加载到CPU或GPU中

    # 保存在 CPU, 加载到 GPU
    model.load_state_dict(torch.load(PATH, map_location="cuda:0")) 
    
    # 保存在 GPU, 加载到 CPU
    device = torch.device('cpu')
    model.load_state_dict(torch.load(PATH, map_location=device))
    
  • model.load_state_dict()

    model.load_state_dict(state_dict, strict=True)
    
    • state_dict (dict) – a dict containing parameters and persistent buffers.
    • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

2.直接保存和加载

保存和加载整个模型 (已经训练完,无需继续训练)

Model class must be defined somewhere

# 保存
torch.save(model, PATH)
# 加载
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

3.使用state_dict保存加载

PyTorch 中,torch.nn.Module里面的可学习的参数 (weights 和 biases) 都放在model.parameters()里面。而 state_dict 是一个 Python dictionary object,将每一层映射到它的 parameter tensor 上。注意:只有含有可学习参数的层 (convolutional layers, linear layers),或者含有 registered buffers 的层 (batchnorm’s running_mean) 才有 state_dict。优化器的对象 (torch.optim) 也有 state_dict,存储了优化器的状态和它的超参数。

使用state_dict只保留了权重参数,因此在加载时需要先初始化模型

否则会出现 pytorch AttributeError 报错

保存和加载 state_dict (已经训练完,无需继续训练)

保存

torch.save(model.state_dict(), PATH)

加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

一般保存为.pt.pth 格式的文件。

  1. load_state_dict()函数需要一个 dict 类型的输入,而不是保存模型的 PATH。所以这样 model.load_state_dict(PATH)是错误的,而应该model.load_state_dict(torch.load(PATH))
  2. 如果你想保存验证机上表现最好的模型,那么这样best_model_state=model.state_dict()是错误的。因为这属于浅复制,也就是说此时这个 best_model_state 会随着后续的训练过程而不断被更新,最后保存的其实是个 overfit 的模型。所以正确的做法应该是best_model_state=deepcopy(model.state_dict())

保存和加载 state_dict (没有训练完,还会继续训练)

保存

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

加载

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

把多个模型存进一个文件

保存

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

加载

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

使用其他模型的参数暖启动自己的模型

有时候训练一个新的复杂模型时,需要加载它的一部分预训练的权重。即使只有几个可用的参数,也会有助于 warmstart 训练过程,帮助模型更快达到收敛。

如果手里有的这个 state_dict 缺乏一些 keys,或者多了一些 keys,只要设置strict参数为 False,就能够把 state_dict 能够匹配的 keys 加载进去,而忽略掉那些 non-matching keys。

保存模型 A 的 state_dict :

torch.save(modelA.state_dict(), PATH)

加载到模型 B:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值