可视化工程下载链接
https://download.csdn.net/download/weixin_42899627/76473610
可视化需要用到图像及其对应的类别，因此需要准备 val_map.txt 文件（每行格式为“图像文件名<Tab>类别编号”）；此外，代码还会读取 label_map.txt 来获取类别名称。val_map.txt 的制作方法可参考我的博客：《制作 ImageNet 验证集的标签 val_map.txt（1000 类）》。
完整代码
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import glob
from demo import parse_option
from models import build_model
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
# --- Network setup -----------------------------------------------------------
# Parse the options, build the model from its config, move it to the CPU and
# load the pretrained Swin-Tiny checkpoint.
args, config = parse_option()
model = build_model(config)
model.to("cpu")
checkpoint = torch.load(
    "swin_tiny_patch4_window7_224.pth",
    map_location=torch.device("cpu"),
)
# Remove every "module." substring from the parameter names (presumably left
# by a DataParallel/DDP wrapper when the checkpoint was saved) before loading.
model.load_state_dict(
    {name.replace("module.", ""): value for name, value in checkpoint["model"].items()}
)
model.eval()  # evaluation mode for the CAM passes

# --- Layer whose activations/gradients are handed to GradCAM ----------------
# Last entry of model.layers (assumed to be the final Swin stage -- confirm).
target_layer = model.layers[-1]
# Look up each validation image together with its class; the class id is the
# target category used for the Grad-CAM backward pass.
def images_and_labels(label_map_path="E:/DL/transformer/label_map.txt",
                      val_map_path="E:/DL/transformer/val_map.txt",
                      image_root="E:/DL/transformer/ILSVRC2012_img_val/",
                      encoding=None):
    """Return ``(image_path, class_id, class_name)`` for every val_map entry.

    Args:
        label_map_path: text file containing a Python dict literal. Its values
            are stacked into an array and column 1 is taken as the class name,
            so each value must be a sequence with at least two items.
        val_map_path: tab-separated file, one ``<file name>\\t<class id>`` per
            line. Blank lines are skipped.
        image_root: directory joined in front of every file name.
        encoding: text encoding for both files; ``None`` keeps the platform
            default, matching the previous behaviour.

    Returns:
        list of ``(str, int, numpy.str_)`` tuples, in val_map order.

    Raises:
        ValueError: if the label map file is not a plain Python literal.
    """
    import ast
    import os

    with open(label_map_path, encoding=encoding) as f:
        # literal_eval only accepts Python literals; eval() would execute any
        # code placed in the file.
        label_map = ast.literal_eval(f.read())
    class_names = np.array(list(label_map.values()))[:, 1]

    images = []
    with open(val_map_path, "r", encoding=encoding) as f:
        for line_str in f:
            if not line_str.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line_str.split("\t")
            rgb_img_dir = os.path.join(image_root, fields[0])
            class_id = int(fields[1])  # int() ignores the trailing newline
            images.append((rgb_img_dir, class_id, class_names[class_id]))
    return images
# Transformer-specific: turn the token sequence into a CNN-style feature map.
def reshape_transform(tensor, height=7, width=7):
    """Convert a (batch, tokens, channels) tensor to (batch, channels, height, width).

    The token axis is laid out row-major on a ``height`` x ``width`` grid
    (token ``t`` lands at row ``t // width``, column ``t % width``), so the
    number of tokens must equal ``height * width``.
    """
    batch_size, channels = tensor.size(0), tensor.size(2)
    grid = tensor.reshape(batch_size, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W): channels right after the batch axis,
    # like CNN activations.
    return grid.permute(0, 3, 1, 2)
# Image preprocessing.
def img_process(rgb_img_dir, size=(224, 224), mean=None, std=None):
    """Load one image file and turn it into a model input tensor.

    Args:
        rgb_img_dir: path of the image file.
        size: target size, passed to ``cv2.resize`` as (width, height).
        mean, std: optional per-channel statistics. When both are given, the
            [0, 1]-scaled image is normalised as ``(img - mean) / std``.

    Returns:
        torch.float32 tensor of shape (1, 3, size[1], size[0]).
    """
    print(rgb_img_dir)
    img = plt.imread(rgb_img_dir)
    # plt.imread returns integer arrays for JPEG but floats already in [0, 1]
    # for PNG; only integer data must be rescaled (dividing PNGs again by 255
    # would make them almost black).
    if np.issubdtype(img.dtype, np.integer):
        img = img / np.iinfo(img.dtype).max
    img = cv2.resize(img.astype(np.float64), size)
    # Grayscale files load as (H, W): replicate to 3 channels. Extra channels
    # (e.g. alpha) are dropped so the model always receives RGB.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    img = img[:, :, :3]
    # NOTE(review): by default no mean/std normalisation is applied. If the
    # checkpoint was trained on normalised inputs, pass mean/std here --
    # confirm against the training config.
    if mean is not None and std is not None:
        img = (img - np.asarray(mean)) / np.asarray(std)
    # HWC -> CHW, then add a batch axis of size 1.
    batch = img.transpose(2, 0, 1)[np.newaxis, :]
    print(batch.shape)
    return torch.tensor(batch, dtype=torch.float32)
# Build the Grad-CAM object once; it is reused for every image below.
cam = GradCAM(model=model, target_layer=target_layer, reshape_transform=reshape_transform)

# Visualisation: for the first few validation images show the raw CAM (left)
# and the CAM overlaid on the image (right), titled with the class name.
num_images = 5
for rgb_img_dir, class_id, class_name in images_and_labels()[:num_images]:
    # target_category is the class whose score is back-propagated.
    grayscale_cam = cam(input_tensor=img_process(rgb_img_dir), target_category=class_id)
    print(grayscale_cam.shape)
    # Some pytorch-grad-cam versions return a batch (N, H, W); keep the single
    # image so imshow / show_cam_on_image receive an (H, W) map.
    if grayscale_cam.ndim == 3:
        grayscale_cam = grayscale_cam[0]
    print("class_name:", class_name)

    plt.figure(figsize=(10, 8))
    plt.subplot(121)
    plt.imshow(grayscale_cam)

    rgb_img = plt.imread(rgb_img_dir)
    # Integer data (JPEG) is scaled to [0, 1]; PNG data from plt.imread is
    # already float in [0, 1] and must not be divided again.
    if np.issubdtype(rgb_img.dtype, np.integer):
        rgb_img = rgb_img / np.iinfo(rgb_img.dtype).max
    rgb_img = cv2.resize(rgb_img.astype(np.float64), (224, 224))
    # The overlay needs 3 channels: replicate grayscale, drop any alpha.
    if rgb_img.ndim == 2:
        rgb_img = np.stack([rgb_img] * 3, axis=-1)
    visualization = show_cam_on_image(rgb_img[:, :, :3], grayscale_cam)
    plt.subplot(122)
    plt.imshow(visualization)
    plt.title(class_name)
    plt.show()
    plt.pause(2)
    plt.close()
代码参考 GitHub 仓库：https://github.com/jacobgil/pytorch-grad-cam