关于torch.load加载网络后,发现要用的网络每一层的名字多了前缀module.的问题

当发现使用如下命令加载保存好的模型,而保存好的模型比实际搭建的模型每一层的名字都多了.module前缀时:

checkpoint = torch.load('model.pth.tar')
model.load_state_dict(checkpoint['state_dict'])

可以从很多大佬的教程中找到去掉前缀的方法了,在这里就不细说了,看起来也很麻烦

这里说一下为什么会出现这个.module前缀,搞清楚这一点可以从根源上避免这个问题
这是因为我们在训练的代码中使用了“torch.nn.DataParallel()”,这个命令是将网络在多块gpu中进行训练然后合并,但是在test的时候没有使用这个命令

model = torch.nn.DataParallel(model,device_ids=[0,x]) //x代表gpu个数

由此,解决方法如下:
train 和 test两部分同时添加上述命令,或者同时都不添加即可解决该问题。通常我选择都不使用该命令,因为大部分情况下本人使用的笔记本都只有一块gpu。
记在这里以备查阅

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值