pytorch学习笔记(1)-模型保存和加载大全

pytorch模型的保存和加载网上已经有很多人进行详细讲解了,我这里只是记录一下本人的学习笔记,方便后续查看。

1.模型的保存

1.1 两种模型保存的方法

 1. torch.save(model, save_name) 
 2. torch.save(model.state_dict(), save_name)

  方法1保存模型的参数和模型结构,方法2仅仅保存模型的参数,因此第二种方法占用存储空间更少,是pytorch官方推荐的方法,但是经过验证后发现,对常见的基于Res50的检测网络,方法1占用的空间并不会多太多,因此目前我仍然常用方法1。

1.2 模型保存的设备

   以torch.save(model, save_name)保存方法为例,torch.save方法不会改变模型参数所处的设备,即保存前model参数所处的设备也会被保存下来。举例来讲,model前一半参数位于cpu,后一半参数位于gpu,则重新加载所保存的模型后,前一半参数也是位于cpu、后一半也是位于gpu。

1.3 半精度保存

   目前pytorch的混合精度训练已经很成熟了,半精度的模型参数会大大加快模型的训练、推理速度,而且不会降低网络模型的精度,因此在进行模型保存时,可以考虑保存半精度的模型(参考yolov5),这样会大大降低存储需求,半精度保存方法为:

    torch.save(model.half(), save_name)

1.4训练过程中checkpoint(检查点)的保存

   torch.save方法不仅仅可以保存模型的参数和结构,还可以保存优化器、损失等数据,以方便因意外导致训练中断后重启训练,只需要将所要保存的数据组织为字典格式即可,如:

ckpt = {
     'model':model.half(),
     'optimizer': optimizer.state_dict(),
}
torch.save(ckpt, save_name);

2.模型的加载

方法为

torch.load(save_name, map_location=None)

   根据pytorch官网对torch.load函数的描述,模型加载分两种情况:
(1)当没有指定torch.load函数的map_location参数时,torch.load会首先加载模型的参数到cpu上,然后把模型参数移动到其保存的位置;
(2)当指定函数参数map_location=device时,torch.load则会把模型参数加载到指定的device上;

3.分布式训练模型的保存和加载

3.1模型的加载

   在分布式训练中,需要保证不同进程的模型的参数,在训练前是完全相同的。不同进程是分别进行预训练模型加载的,由于部分网络层(如分类head)是随机初始化参数的,此时不同进程的model的参数并不是完全相同的。

  但在使用torch.nn.parallel.DistributedDataParallel对model进行包裹后,RANK=0的进程会把本进程模型的参数广播到其他进程,由此实现所有进程模型参数的同步。因此,在分布式训练时,只需要保证模型参数的加载是在DDP包裹前完成的,即可保证不同进程的模型参数完全相同

   此外,需要注意的是,如果预训练模型是保存在cuda:0上,那么torch.load(save_name)也会把预训练模型参数加载到cuda:0上,由于分布式训练的不同进程会分别加载,所以这些进程都会加载到cuda:0上,从而造成GPU0的显存占用过大,因此合适的加载方法如下,加载到cpu上,减少显存占用。

torch.load(save_name, map_location='cpu')

3.2模型的保存

   在分布式训练阶段,由于各个进程的模型参数始终保持同步,因此只需要在一个进程中保存模型即可,其他进程不需要保存模型。并且由于使用了DDP对model进行包裹,因此模型保存的程序为

torch.save(model.module.half(), save_name)

而不是

torch.save(model.half(), save_name)
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值