Detectron2热力图可视化

本文详细介绍了如何在Detectron2框架下,针对fasterRCNN模型进行热力图可视化。通过创建heatmap.py脚本,修改网络输出和配置文件,可以得到模型不同阶段(P3-P7)的特征映射热力图。此过程适用于理解模型内部工作原理。
摘要由CSDN通过智能技术生成

本文主要针对Detectron2中faster RCNN进行热力图可视化,其他网络修改略有不同,但都大同小异,主要是要找到合适的输出维度,满足维度要求即可对该层进行可视化。可视化具体步骤如下:

先在主目录下创建heatmap.py,这个是热力图生成的核心代码:

import argparse
import cv2
import numpy as np
import os
import torch
import tqdm
from detectron2.data.detection_utils import read_image
import time
from detectron2.utils.logger import setup_logger
 
def setup_cfg(args): # 获取cfg,并合并,不用细看,和demo.py中的一样
    # load config from file and command-line arguments
    from detectron2.config import get_cfg
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.freeze()
    return cfg
 
 
def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin models")
    parser.add_argument(
        "--config-file",
        default="", # 此处是配置文件,在config下选择你的yaml文件
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--input", default='', nargs="+", help="A list of space separated input images") # 图片文件夹路径,目前只支持图片输入,
#要输入视频或者调用摄像头,可以自行修改代码 
    parser.add_argument(
        "--output",
        default='', # 输出文件夹路径
        help="A file or directory to save output visualizations. "
             "If not given, will show output in an OpenCV window.",
    )
 
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5, #置信度阈值
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser
 
 
def featuremap_2_heatmap(feature_map):
    assert isinstance(feature_map, torch.Tensor)
    
    # 1*256*200*256 # feat的维度要求,四维
    feature_map = feature_map.detach()
 
    # 1*256*200*256->1*200*256
    heatmap = feature_map[:,0,:,:]*0
    for c in range(feature_map.shape[1]):
        heatmap+=feature_map[:,c,:,:]
    heatmap = heatmap.cpu().numpy()
    heatmap = np.mean(heatmap, axis=0)
 
    heatmap = np.maximum(heatmap, 0)
    heatmap /= np.max(heatmap)
 
    return heatmap
 
def draw_feature_map(img_path, save_dir):
   
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
 
    from predictor import VisualizationDemo
    demo = VisualizationDemo(cfg)
    for imgs in tqdm.tqdm(os.listdir(img_path)):
        img = read_image(os.path.join(img_path, imgs), format="BGR")
        start_time = time.time()
        predictions = demo.run_on_image(img) # 后面需对网络输出做一定修改,
        # 会得到一个字典P3-P7的输出
        logger.info(
            "{}: detected in {:.2f}s".format(
                imgs, time.time() - start_time))
        i=0
        for featuremap in list(predictions.values()):
            heatmap = featuremap_2_heatmap(featuremap)
            # 200*256->512*640
            heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的
            # 大小调整为与原始图像相同         
            heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
            # 512*640*3
            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将热力图应用于原
            # 始图像       
            superimposed_img = heatmap * 0.7 + 0.3*img  # 热力图强度因子,修改参数,
            得到合适的热力图
            cv2.imwrite(os.path.join(save_dir,imgs+str(i)+'.jpg'),
            superimposed_img)  # 将图像保存                    
            i=i+1
 
 
from argparse import ArgumentParser
 
def main():
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    draw_feature_map(args.input,args.output)
 
if __name__ == '__main__':
    main()

第二步,修改网络输出

# 修改可视化需要的文件predict.py中的run_on_image,对相应代码进行注释
 
def run_on_image(self, image):
 
        vis_output = None
        predictions = self.predictor(image)
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        # image = image[:, :, ::-1]
        # visualizer = Visualizer(image, self.metadata, 
        # instance_mode=self.instance_mode)             
        # if "panoptic_seg" in predictions:
        #     panoptic_seg, segments_info = predictions["panoptic_seg"]
        #     vis_output = visualizer.draw_panoptic_seg_predictions(
        #         panoptic_seg.to(self.cpu_device), segments_info
        #     )
        # else:
        #     if "sem_seg" in predictions:
        #         vis_output = visualizer.draw_sem_seg(
        #             predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
        #         )
        #     if "instances" in predictions:
        #         instances = predictions["instances"].to(self.cpu_device)
        #         vis_output =
        # visualizer.draw_instance_predictions(predictions=instances)
 
        return predictions
        # return predictions, vis_output

然后对detectron2/engine/defaults.py中 class DefaultPredictor进行输出修改

def __call__(self, original_image):
       
        with torch.no_grad():  
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image =  self.transform_gen.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
 
            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])
            # ----------------------------------------
            # predictions = self.model([inputs])[0]
            # ----------------------------------------
            return predictions

就是对网络进行修改detectron2/modeling/meta_arch/rcnn.py,主要修改
class GeneralizedRCNN(nn.Module)中的inference函数,只输出骨干,会得到一个字典,P3-P7,不同骨干输出不同

def inference(self, batched_inputs, detected_instances=None, do_postprocess=True, init=False):
       
        assert not self.training
 
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        return features

最后运行python heatmap.py 就会得到5张热力图代表想应的P3-P7输出,示意图如下所示:
在这里插入图片描述
本文主要是防止原始内容失效,方便以后个人使用。转载自此处

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值