mmdetection特征可视化V2
前言
在上一篇博客中介绍了特征图可视化,发现还可以对其简化,不用修改一大堆东西,直接在我们想要可视化的地方直接调用可视化函数即可,方便大家在debug的时候可以快速的看到自己想要看的特征图
一、特征图可视化
1.新建feature_visualization.py文件
该文件我自己建立在tools文件夹下面,自己也可以新建一个文件夹放进去,里面主要包含两个函数,跟上一篇博客中的基本一样:
import cv2
import mmcv
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
def featuremap_2_heatmap(feature_map):
assert isinstance(feature_map, torch.Tensor)
feature_map = feature_map.detach()
heatmap = feature_map[:,0,:,:]*0
heatmaps = []
for c in range(feature_map.shape[1]):
heatmap+=feature_map[:,c,:,:]
heatmap = heatmap.cpu().numpy()
heatmap = np.mean(heatmap, axis=0)
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
heatmaps.append(heatmap)
return heatmaps
def draw_feature_map(features,save_dir = 'feature_map',name = None):
i=0
if isinstance(features,torch.Tensor):
for heat_maps in features:
heat_maps=heat_maps.unsqueeze(0)
heatmaps = featuremap_2_heatmap(heat_maps)
# 这里的h,w指的是你想要把特征图resize成多大的尺寸
# heatmap = cv2.resize(heatmap, (h, w))
for heatmap in heatmaps:
heatmap = np.uint8(255 * heatmap)
# 下面这行将热力图转换为RGB格式 ,如果注释掉就是灰度图
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
superimposed_img = heatmap
plt.imshow(superimposed_img,cmap='gray')
plt.show()
else:
for featuremap in features:
heatmaps = featuremap_2_heatmap(featuremap)
# heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) # 将热力图的大小调整为与原始图像相同
for heatmap in heatmaps:
heatmap = np.uint8(255 * heatmap) # 将热力图转换为RGB格式
# heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
# superimposed_img = heatmap * 0.5 + img*0.3
superimposed_img = heatmap
plt.imshow(superimposed_img,cmap='gray')
plt.show()
# 下面这些是对特征图进行保存,使用时取消注释
# cv2.imshow("1",superimposed_img)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# cv2.imwrite(os.path.join(save_dir,name +str(i)+'.png'), superimposed_img)
# i=i+1
2.使用方法
将上述代码文件准备好后,后面的步骤就很简单了,直接在你想使用的地方直接调用函数即可,实例如下,比如我们用Faster_rcnn网络,就在two_stage.py文件里面,找到**extract_feat()**函数,增加两行代码,如下所示:
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
# 可视化resnet产生的特征
from tools.feature_visualization import draw_feature_map
draw_feature_map(x)
if self.with_neck:
x = self.neck(x)
# 可视化FPN产生的特征
from tools.feature_visualization import draw_feature_map
draw_feature_map(x)
return x
用起来还是非常简单的,假如你用了其他的网络检测模型,需要在mmdet/models/detectors下面的文件中找到你所用的detector,这个在model的config文件看你model的type就可以查到