- 错误原因是在train使用了单GPU,但在test里面使用多GPU。
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.encoder_stage1.0.weight".
Unexpected key(s) in state_dict: "encoder_stage1.0.weight".
- code里这句话就是使用多GPU的意思
model = torch.nn.DataParallel(model, device_ids=args.gpu_id)
解决方法:
在前面添加上‘module.’
ckpt = checkpoint['net']
new_ckpt = {}
for k, v in ckpt.items():
k = 'module.' + k
new_ckpt[k]=v
model.load_state_dict(new_ckpt)