RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: “module.corr_fn.setrans.query.weight”,
Unexpected key(s) in state_dict: “model”, “optimizer”, “lr_scheduler”, “logger”.
原代码:
flow_model = torch.nn.DataParallel(CRAFT(args))
flow_model.load_state_dict(torch.load(args.model))
flow_model = flow_model.module
flow_model.to(DEVICE)
flow_model.eval()
解决心路:
不正确的解决方案,虽然不会显示错误了,但是预测的结果不正确,这根本没有load进去,而是直接忽略了,这样出来的结果不正确
flow_model.load_state_dict(torch.load(args.model))
变成 flow_model.load_state_dict(torch.load(args.model), strict=False)
以下才是正确的代码:
flow_model = nn.DataParallel(CRAFT(args))
checkpoint = torch.load(args.model)
flow_model.load_state_dict(checkpoint['model'])
flow_model.to(DEVICE)
flow_model.eval()
参考链接:https://zhuanlan.zhihu.com/p/365831931