pytorch特征图 可视化 清测有效,Mark一下。
引用自 https://my.oschina.net/u/4300877/blog/4693569
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 27 09:25:51 2020
@author: LX
"""
#%%特征可视化
import matplotlib.pyplot as plt
import cv2
import numpy as np
from PIL import Image
from torchvision import models, transforms
import torch
import timm
class SaveConvFeatures():
def __init__(self, m): # module to hook
self.hook = m.register_forward_hook(self.hook_fn)
def hook_fn(self, module, input, output):
self.features = output.data
def remove(self):
self.hook.remove()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
t = transforms.Compose([transforms.Resize((224, 224)), #128, 128
transforms.ToTensor(),
transforms.Normalize(mean=0.1307, std=0.3081)])
img_file = r"C:\Users\LX\Pictures\elephant.jpg"
img = Image.open(img_file)
img = t(img).unsqueeze(0).to(device)
custom_model = models.resnet50(pretrained=True)
# custom_model = timm.create_model('resnest50d', pretrained=True)
# custom_model.layer4是自己需要查看特征输出的卷积层
hook_ref = SaveConvFeatures(custom_model.layer4)
with torch.no_grad():
custom_model(img)
conv_features = hook_ref.features # [1,2048,7,7]
print('特征图输出维度:', conv_features.shape) #其实得到特征图之后可以自己编写绘图程序
hook_ref.remove()
def show_feature_map(img_src, conv_features):
'''可视化卷积层特征图输出
img_src:源图像文件路径
conv_feature:得到的卷积输出,[b, c, h, w]
'''
img = Image.open(img_file).convert('RGB')
height, width = img.size
heat = conv_features.squeeze(0)#降维操作,尺寸变为(2048,7,7)
heat_mean = torch.mean(heat,dim=0)#对各卷积层(2048)求平均值,尺寸变为(7,7)
heatmap = heat_mean.numpy()#转换为numpy数组
heatmap /= np.max(heatmap)#minmax归一化处理
heatmap = cv2.resize(heatmap,(img.size[0],img.size[1]))#变换heatmap图像尺寸,使之与原图匹配,方便后续可视化
heatmap = np.uint8(255*heatmap)#像素值缩放至(0,255)之间,uint8类型,这也是前面需要做归一化的原因,否则像素值会溢出255(也就是8位颜色通道)
heatmap = cv2.applyColorMap(heatmap,cv2.COLORMAP_JET)#颜色变换
plt.imshow(heatmap)
plt.show()
# heatmap = np.array(Image.fromarray(heatmap).convert('L'))
superimg = heatmap*0.4+np.array(img)[:,:,::-1] #图像叠加,注意翻转通道,cv用的是bgr
cv2.imwrite('./superimg.jpg',superimg)#保存结果
# 可视化叠加至源图像的结果
img_ = np.array(Image.open('./superimg.jpg').convert('RGB'))
plt.imshow(img_)
plt.show()
show_feature_map(img_file, conv_features)