pytorch--多卡单卡模型加载

\quad 模型的保存和加载参照 pytorch模型保存及加载详解
\quad 多卡保存的时候,在model的state_dict()参数多了一个"module."的前缀,其他的参数保存的时候单卡多卡保存并没有区别。因此在模型相互加载之前把这个处理好这个前缀就可以了。

一、 ⇒ l o a d   t o   \xRightarrow{load \ to\ } load to  单卡

def strip_prefix(self, state_dict, prefix='module.'):
    if not all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    stripped_state_dict = {}
    for key in list(state_dict.keys()):
        stripped_state_dict[key.replace(prefix, '')] = state_dict.pop(key)
    return stripped_state_dict

二、 ⇒ l o a d   t o   \xRightarrow{load \ to\ } load to  多卡

def add_prefix(self, state_dict, prefix='module.'):
    if all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    stripped_state_dict = {}
    for key in list(state_dict.keys()):
        key2 = prefix + key
        stripped_state_dict[key2] = state_dict.pop(key)
    return stripped_state_dict

三、使用

checkpoint = torch.load(pretrain)
if multi_gpu is not None:
	model.load_state_dict(self.add_prefix(checkpoint['state_dict']))
else:
    model.load_state_dict(self.strip_prefix(checkpoint['state_dict']))
optimizer.load_state_dict(checkpoint['optimizer'])
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一米七八_FZH

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值