pytorch 多卡并行载入部分网络模型

pytorch 多卡并行载入部分网络模型

我们在做深度学习的时候经常会使用预训练的模型。但是一旦自己修改了网络架构,就无法load pretrained model。 因为模型文件保存的参数,有一部分是不需要的,或者有一部分参数是缺失的。

为了在这种情况下,成功导入模型,我们需要如下操作

操作的前提是我们存在已保存的模型参数

model = Net()
torch.save(model.state_dict(),'xxx.path')

接下来就好办了

 device = torch.device("cuda:2" if args.cuda else "cpu")

 #Try to load models
 model = DGCNN(args).to(device)
 print(str(model))

 device_ids = [2,3]
 model = nn.DataParallel(model,device_ids=device_ids) #使用2,3号显卡进行训练
 save_model = torch.load('model.t7')
 
 model_dict =  model.state_dict()

 state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
 print(state_dict.keys())  
 
 model_dict.update(state_dict)
 model.load_state_dict(model_dict)

update之后,model_dict和state_dict中具有相同键的值已经同步了。
可以开始愉快的训练了!

参考

https://blog.csdn.net/qq_34914551/article/details/87871134

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值