Swin Transformer 类可视化(grad CAM)

可视化工程下载链接
https://download.csdn.net/download/weixin_42899627/76473610

在这里插入图片描述
在这里插入图片描述

可视化是需要图像和它对应的类别的,所以需要有 val_map.txt 文件,需要的参考我的blog:制作 ImageNet 验证集的标签 val_map.txt(1000类)

完整代码

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import glob
from demo import parse_option
from models import build_model
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

'''配置网络'''
args, config = parse_option()
model = build_model(config)
model.to("cpu")
checkpoint = torch.load("swin_tiny_patch4_window7_224.pth",map_location=torch.device('cpu'))
model.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['model'].items()})
model.eval()

'''获取特征图的层'''
target_layer = model.layers[-1]

'''查找对应的图像和类别(反向传播需要)'''
def images_and_labels():
    label_map = open("E:/DL/transformer/label_map.txt").read()
    label_map = eval(label_map)#str转
    label_map = np.array(list(label_map.values()))[:,1]
    with open("E:/DL/transformer/val_map.txt",'r') as f:
        images = []
        contents = f.readlines()
        for line_str in contents:
            path_contents = [c for c in line_str.split('\t')]
            rgb_img_dir = "E:/DL/transformer/ILSVRC2012_img_val/"+path_contents[0]
            class_id = int(path_contents[1])
            images.append((rgb_img_dir,class_id,label_map[class_id]))
        
    return images

'''transformer特殊需要'''
def reshape_transform(tensor, height=7, width=7):
    result = tensor[:,  :  , :].reshape(tensor.size(0),
        height, width, tensor.size(2))

    # Bring the channels to the first dimension, like in CNNs.
    result = result.transpose(2, 3).transpose(1, 2)
    return result

'''图像预处理'''
def img_process(rgb_img_dir): 
    file_dir = rgb_img_dir
    print(file_dir)
    img = plt.imread(file_dir)
    img = img/255
    data_loader = cv2.resize(img,(224,224))
    data_loader = data_loader.transpose(2,0,1)[np.newaxis,:]
    print(data_loader.shape)
    image = torch.tensor(data_loader, dtype=torch.float32)
    #target = torch.from_numpy(np.array([230]).astype(np.int64))
    #print("label:",target)

    return image

# This should be constructed once:
cam = GradCAM(model=model, target_layer=target_layer, reshape_transform=reshape_transform)

i=0
'''可视化代码'''
for rgb_img_dir,class_id,class_name in images_and_labels():
    # And then cam be used on many images:
    grayscale_cam = cam(input_tensor=img_process(rgb_img_dir), target_category=class_id)
    print(grayscale_cam.shape)
    print("class_name:",class_name)
    plt.figure(figsize=(10,8))
    plt.subplot(121)
    plt.imshow(grayscale_cam)
    
    rgb_img = plt.imread(rgb_img_dir)
    rgb_img = cv2.resize(rgb_img,(224,224))
    visualization = show_cam_on_image(rgb_img/255, grayscale_cam)
    plt.subplot(122)
    plt.imshow(visualization)
    plt.title(class_name)
    plt.show()
    plt.pause(2)
    plt.close()
    
    i = i+1
    if i >= 5:
        break

代码参考GitHub链接:jacobgil/pytorch-grad-cam

  • 9
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值