使用openpose pytorch版本查看中间热力图结果,需要加载部分参数,过程如下
1.把模型的结构加载进来
pretrained_dict = torch.load(model_body25)
model = bodypose_25_model()
2.通过字典形式,加载网络中的部分参数
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
3.转移cuda,改成eval模式,如果模型中有relu或BN层,切记一定要加eval(),否则很有可能每次预测的结果都不一样或达不到预期
# if torch.cuda.is_available():
# model = model.cuda()
model = model.cuda()
model.eval()
4.最终效果