1. 下载两个文件 vis_cam.py 和 det_cam_visualizer.py(稍后上传)。操作前请先备份自己的环境,以防后续出错。
2. 将 vis_cam.py 放到 demo 目录中。
3. 将 det_cam_visualizer.py 放到 mmdet\utils\ 目录中(Linux/macOS 下为 mmdet/utils/)。
4. 根据需要选择下面的一条命令运行:
# FeatmapAM method
python demo/vis_cam.py demo/demo.jpg configs/retinanet/retinanet_r50_fpn_1x_coco.py retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth
# EigenCAM method
python demo/vis_cam.py demo/demo.jpg configs/retinanet/retinanet_r50_fpn_1x_coco.py retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth --method eigencam
# AblationCAM method
python demo/vis_cam.py demo/demo.jpg configs/retinanet/retinanet_r50_fpn_1x_coco.py retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth --method ablationcam
# AblationCAM method and save img
python demo/vis_cam.py demo/demo.jpg configs/retinanet/retinanet_r50_fpn_1x_coco.py retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth --method ablationcam --out-dir save_dir
# GradCAM
python demo/vis_cam.py demo/demo.jpg configs/retinanet/retinanet_r50_fpn_1x_coco.py retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth --method gradcam
5. 目前支持 RetinaNet、Faster R-CNN、Mask R-CNN 和 YOLOX。
6. 首次运行时会提示安装 grad-cam(即 pytorch_grad_cam 包)。
输入命令 pip install grad-cam 即可,等待安装完成后就可以直接可视化热力图了。
7. 部分 Windows 系统可能会出现版本错误:安装 grad-cam 时有一定概率会同时更新你的 torch 版本。请记住自己之前使用的 torch 版本,重新安装回原来的版本即可。
8. det_cam_visualizer.py 的源码如下:
import bisect
import copy
import cv2
import mmcv
import numpy as np
import torch
import torch.nn as nn
import torchvision
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
try:
from pytorch_grad_cam import (AblationCAM, AblationLayer,
ActivationsAndGradients)
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
except ImportError:
raise ImportError('Please run `pip install "grad-cam"` to install '
'3rd party package pytorch_grad_cam.')
from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
def reshape_transform(feats, max_shape=(20, 20), is_need_grad=False):
    """Reshape and aggregate feature maps when the input is a multi-layer
    feature map.

    Takes one tensor or a sequence of tensors with (possibly) different
    spatial sizes. It takes their absolute values, resizes each one to a
    common ``(h, w)`` with bilinear interpolation, and concatenates the
    results along dim 1.

    Args:
        feats (torch.Tensor | Sequence[torch.Tensor]): Feature map(s). Torch
            bilinear interpolation requires 4-D input, so each tensor is
            presumably ``(N, C, H, W)``. TODO: confirm against callers.
        max_shape (int | Sequence[int]): Upper bound on the target
            ``(h, w)``. A single value (an int or a length-1 sequence) is
            used for both sides. If it contains ``-1``, the largest input
            height/width is used without any cap.
        is_need_grad (bool): Set by gradient-based CAM methods, which only
            support a single activation layer.

    Returns:
        torch.Tensor: The resized absolute activations concatenated along
        dim 1.

    Raises:
        NotImplementedError: If ``is_need_grad`` is set and ``feats`` is a
            sequence of layers rather than a single tensor.
        ValueError: If ``feats`` is an empty sequence.
    """
    # Normalise max_shape to a 2-element sequence.
    if isinstance(max_shape, int):
        max_shape = (max_shape, )
    if len(max_shape) == 1:
        max_shape = tuple(max_shape) * 2

    if isinstance(feats, torch.Tensor):
        feats = [feats]
    else:
        if is_need_grad:
            raise NotImplementedError('The `grad_base` method does not '
                                      'support output multi-activation layers')
        # Materialise once: feats is traversed several times below, which
        # would silently exhaust a generator.
        feats = list(feats)
    if not feats:
        raise ValueError('`feats` must contain at least one feature map')

    max_h = max(feat.shape[-2] for feat in feats)
    max_w = max(feat.shape[-1] for feat in feats)
    if -1 in max_shape:
        target_shape = (max_h, max_w)
    else:
        # Never upsample beyond the largest input map.
        target_shape = (min(max_h, max_shape[0]), min(max_w, max_shape[1]))

    activations = [
        torch.nn.functional.interpolate(
            torch.abs(feat), target_shape, mode='bilinear')
        for feat in feats
    ]
    return torch.cat(activations, dim=1)
class DetCAMModel(nn.Module):
"""Wrap the mmdet model class to facilitate handling of non-tensor
situations during inference."""
def __init__(self, cfg, checkpoint, score_thr, device='cuda:0'):
    """Store the wrapper settings and build the wrapped detector.

    Args:
        cfg: mmdet config. Kept as-is; ``build_detector`` and
            ``set_input_data`` work on deep copies of it.
        checkpoint: Weights passed to ``load_checkpoint`` (e.g. a ``.pth``
            file), or ``None`` to skip loading.
        score_thr: Score threshold, stored as ``self.score_thr``.
        device (str): Device the detector is moved to.
    """
    super().__init__()
    # build_detector() reads cfg, checkpoint and device, so they must be
    # assigned before it is called.
    self.cfg, self.device = cfg, device
    self.score_thr = score_thr
    self.checkpoint = checkpoint
    self.detector = self.build_detector()

    # Per-call inference state; return_loss and img are updated later by
    # set_return_loss() / set_input_data().
    self.return_loss = False
    self.input_data = None
    self.img = None
def build_detector(self):
    """Build the mmdet detector described by ``self.cfg`` and load weights.

    Works on a deep copy of ``self.cfg`` so the stored config is never
    mutated. When ``self.checkpoint`` is set, its weights are loaded (mapped
    to CPU first). Class names are then taken from the checkpoint's ``meta``
    entry. If none are stored there, COCO class names are used and a warning
    is emitted. Finally the detector is moved to ``self.device`` and switched
    to eval mode.

    Returns:
        The constructed detector module.
    """
    import warnings

    cfg = copy.deepcopy(self.cfg)
    # When this function is a method of the class, this name resolves to the
    # module-level mmdet ``build_detector`` import, not to the method itself.
    detector = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    if self.checkpoint is not None:
        checkpoint = load_checkpoint(
            detector, self.checkpoint, map_location='cpu')
        # Guard against a missing or explicitly-None ``meta`` entry.
        meta = checkpoint.get('meta') or {}
        if 'CLASSES' in meta:
            detector.CLASSES = meta['CLASSES']
        else:
            # Deliberately no warnings.simplefilter() here: it would change
            # the warning filters of the whole process. The default filter
            # already reports this warning once per call site.
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            detector.CLASSES = get_classes('coco')
    detector.to(self.device)
    detector.eval()
    return detector
def set_return_loss(self, return_loss):
    """Record whether the wrapper should run in loss-returning mode.

    Args:
        return_loss: Stored as ``self.return_loss``. When truthy,
            ``set_input_data`` requires both ``bboxes`` and ``labels``.
    """
    setattr(self, 'return_loss', return_loss)
def set_input_data(self, img, bboxes=None, labels=None):
self.img = img
cfg = copy.deepcopy(self.cfg)
if self.return_loss:
assert bboxes is not None
assert labels is not None
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
cfg.data.test.pipeline[1].transforms[-1] = dict(
type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
test_pipeline = Compose(cfg.data.test.pipeline)
# TODO: support mask
data = dict(
img=self.img,
gt_bboxes=bboxes,
gt_labels=labels