pytorch之保存与加载模型

pytorch之保存与加载模型

本篇笔记译自pytorch官网tutorial,用于方便查看。
pytorch与保存、加载模型有关的常用函数3个:

  • torch.save(): 保存一个序列化的对象到磁盘,使用的是Pythonpickle库来实现的。
  • torch.load(): 解序列化一个pickled对象并加载到内存当中。
  • torch.nn.Module.load_state_dict(): 加载一个解序列化的state_dict对象

1. state_dict

PyTorch中所有可学习的参数保存在model.parameters()中。state_dict是一个Python字典。保存了各层与其参数张量之间的映射。torch.optim对象也有一个state_dict,它包含了optimizerstate,以及一些超参数。

2. 保存&加载模型来inference(recommended)

save

torch.save(model.state_dict(), PATH)

load

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # 当用于inference时不要忘记添加
  • 保存的文件名后缀可以是.pt.pth
  • 当用于inference时不要忘记添加model.eval(), 否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有batch normalization层所带来的的性质
  • 3. 保存&加载整个模型(not recommended)

    save

    torch.save(model, PATH)
    

    load

    # Model class must be defined somewhere
    model = torch.load()
    model.eval()
    

    4. 保存&加载带checkpoint的模型用于inferenceresuming training

    save

    torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss,
      ...
      }, PATH)
    

    load

    model = TheModelClass(*args, **kwargs)
    optimizer = TheOptimizerClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    model.eval()
    # or
    model.train()
    

    5. 保存多个模型到一个文件中

    save

    torch.save({
      'modelA_state_dict': modelA.state_dict(),
      'modelB_state_dict': modelB.state_dict(),
      'optimizerA_state_dict': optimizerA.state_dict(),
      'optimizerB_state_dict': optimizerB.state_dict(),
      ...
      }, PATH)
    

    load

    modelA = TheModelAClass(*args, **kwargs)
    modelB = TheModelAClass(*args, **kwargs)
    optimizerA = TheOptimizerAClass(*args, **kwargs)
    optimizerB = TheOptimizerBClass(*args, **kwargs)
    
    checkpoint = torch.load(PATH)
    modelA.load_state_dict(checkpoint['modelA_state_dict']
    modelB.load_state_dict(checkpoint['modelB_state_dict']
    optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']
    optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']
    
    modelA.eval()
    modelB.eval()
    # or
    modelA.train()
    modelB.train()
    
  • 此情况可能在GANSequence-to-sequence,或ensemble models中使用
  • 保存checkpoint常用.tar文件扩展名
  • 6. Warmstarting Model Using Parameters From A Different Model

    save

    torch.save(modelA.state_dict(), PATH)
    

    load

    modelB = TheModelBClass(*args, **kwargs)
    modelB.load_state_dict(torch.load(PATH), strict=False)
    
    • 在迁移训练时,可能希望只加载部分模型参数,此时可置strict参数为False来忽略那些没有匹配到的keys

    7. 保存&加载模型跨设备

    (1) Save on GPU, Load on CPU
    save

    torch.save(model.state_dict(), PATH)
    

    load

    device = torch.device("cpu")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location=device))
    

    (2) Save on GPU, Load on GPU
    save

    torch.save(model.state_dict(), PATH)
    

    load

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.to(device)
    

    (3) Save on CPU, Load on GPU
    save

    torch.save(model.state_dict(), PATH)
    

    load

    device = torch.device("cuda")
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
    model.to(device)
    

    8. 保存torch.nn.DataParallel模型

    save

    torch.save(model.module.state_dict(), PATH)
    

    load

    # Load to whatever device you want
    

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值