当发现使用如下命令加载保存好的模型,而保存好的模型比实际搭建的模型每一层的名字都多了.module前缀时:
checkpoint = torch.load('model.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
可以从很多大佬的教程中找到去掉前缀的方法了,在这里就不细说了,看起来也很麻烦
这里说一下为什么会出现这个.module前缀,搞清楚这一点可以从根源上避免这个问题
这是因为我们在训练的代码中使用了“torch.nn.DataParallel()”,这个命令是将网络在多块gpu中进行训练然后合并,但是在test的时候没有使用这个命令
model = torch.nn.DataParallel(model,device_ids=[0,x]) //x代表gpu个数
由此,解决方法如下:
train 和 test两部分同时添加上述命令,或者同时都不添加即可解决该问题。通常我选择都不使用该命令,因为大部分情况下本人使用的笔记本都只有一块gpu。
记在这里以备查阅