hook的使用场景多种多样,本文将使用hooks来简单可视化卷积神经网络的特征提取,主要包括以下几个步骤:
- 创建CNN特征提取器(即需要可视化特征图的模型)
- 创建一个保存hook内容的对象
- 为需要可视化的卷积层创建hook
- 可视化特征图
1.创建CNN特征提取器
import torch
import torchvision
model= torchvision.models.resnet34(pretrained=True)
if torch.cuda.is_available():
model.cuda()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
2.创建保存hook内容的对象
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs=[]
save_output = SaveOutput()
3.为需要可视化的卷积层创建hook
可视化卷积层之前,需要查看模型网络层对应的名称,
比如下图中需要查看layer1的特征图,需要找到对应特征图名称【model.module.backbone.layer1[2].relu3】
layer_name = [
model.module.backbone.layer1[2].relu3,
model.module.backbone.layer2[3].relu3,
model.module.backbone.layer3[5].relu3,
model.module.backbone.layer4[2].relu3]
for layer in layer_name:
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
4.特征图可视化
该部分主要是讲第三部分提取到的特征图进行可视化,可视化的方法有两种,一种是将各个通道的特征图求和相加,输出完整特征图,还有一种是将各个通道的特征图进行拼接,将每个通道的特征图分别输出。
# 将特征图进行拼接,输出每个通道特征图
def grid_gray_image(imgs, each_row: int):
'''
imgs shape: batch * size (e.g., 64x32x32, 64 is the number of the gray images, and (32, 32) is the size of each gray image)
each_row: Number of feature maps per line
'''
row_num = imgs.shape[0]//each_row
for i in range(row_num):
img = imgs[i*each_row]
img = (img - img.min()) / (img.max() - img.min())
for j in range(1, each_row):
tmp_img = imgs[i*each_row+j]
tmp_img = (tmp_img - tmp_img.min()) / (tmp_img.max() - tmp_img.min())
img = np.hstack((img, tmp_img))
if i == 0:
ans = img
else:
ans = np.vstack((ans, img))
return ans
#将特征图转换成热力图(将各个通道特征图相加求和)
def featuremap_2_heatmap(feature_map):
assert isinstance(feature_map, torch.Tensor)
feature_map = feature_map.detach()
heatmap = feature_map[:,0,:,:]*0
heatmaps = []
for c in range(feature_map.shape[1]):
heatmap+=feature_map[:,c,:,:]
heatmap = heatmap.cpu().numpy()
heatmap = np.mean(heatmap, axis=0)
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
heatmaps.append(heatmap)
return heatmaps
# 绘制特征图
def draw_feature_map(features,save_dir = '',name = None):
i=0
if isinstance(features,torch.Tensor):
for heat_maps in features:
heat_maps=heat_maps.unsqueeze(0)
heatmaps = featuremap_2_heatmap(heat_maps)
for heatmap in heatmaps:
heatmap = np.uint8(255 * heatmap)
# 下面这行将热力图转换为RGB格式 ,如果注释掉就是灰度图
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
superimposed_img = heatmap
plt.imshow(superimposed_img,cmap='gray')
plt.show()
else:
a = -1
for i , featuremap in enumerate(features):
heatmaps = featuremap_2_heatmap(featuremap)
for heatmap in heatmaps:
# heatmap = (heatmap-heatmap.min())/(heatmap.max()-heatmap.min())
heatmap = np.uint8(255 * heatmap) # 将热力图转换为RGB格式
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
superimposed_img = heatmap
plt.title(name[i%len(name)])
plt.imshow(superimposed_img,cmap='gray')
plt.show()
cv2.imwrite(os.path.join(save_dir, 'rpn_cls'+ str(i) +'.jpg'), superimposed_img)
5.特征图效果展示
本文以旋转目标检测网络ReDet特征图进行展示
第一行为backbone提取的特征图,第二行为rpn_cls特征图,第三行为rpn_reg特征图