pytorch——AttributeError: 'DataParallel' object has no attribute '****'

报错原因:

  • 在使用model = nn.DataParallel(model,device_ids=[0,1])加载模型之后,出现了这个错误:AttributeError: ‘DataParallel’ object has no attribute ‘****’
  • 报错的地方在我后面调用model的一些层时,并没有那些层,输出经过nn.DataParallel的模型参数后,发现每个参数前面多了module,应该是nn.DataParallel将model转换成了model.module。

解决方法:

原:

encoder_id_id = list(map(id, model.embedding_net_id.classifier.parameters()))

改:

model = nn.DataParallel(model,device_ids=[0,1])
encoder_id_id = list(map(id, model.module.embedding_net_id.classifier.parameters()))

总结:

若对模型多卡训练并需要对某些层操作时,需要对调用模型的后面加一个.module

发布了41 篇原创文章 · 获赞 13 · 访问量 9458
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 1024 设计师: 上身试试

分享到微信朋友圈

×

扫一扫,手机浏览