1.在载入模型参数前加上:
model = nn.DataParallel(model)
比如我的:
model_effi7 = torch.nn.DataParallel(model_effi7)
model_effi7.load_state_dict(torch.load(model_path_effi7))
若再出现:RuntimeError:Error(s) in loading state_dict for DataParallel:
则如此修改,从属性state_dict里面复制参数到这个模块和它的后代。如果strict为True, state_dict的keys必须完全与这个模块的方法返回的keys相匹配。如果为False,就不需要保证匹配。
model_effi7.load_state_dict(torch.load(model_path_effi7),False)