之前写过一篇caffe版本的,这次需要pytorch版本的发现踩坑了,主要踩坑是grad不管怎么弄都是None,查了资料发现是pytorch释放了,所以需要加hook才行。记录一下备忘。
import cv2
import torch
img_path = 'path/test_img.jpg' #测试用的图片路径
img_and_saliency_map_name = 'outpath/img_and_saliency_map.jpg'#待求的 图片+saliency intensity map的叠加图路径
#这里按照你自己的网络测试过程把图片读进去就行,反正输入前一般是[N,C,H,W],我这是一张3通道彩色图,176*176的,注意requires_grad=True
img = get_img(img_path)
img = torch.autograd.Variable(img, requires_grad=True)
img = img.view((1,img.size()[0],img.size()