定义钩子函数
import os

import cv2
import numpy as np
import torchvision.utils as vutil
def hook_func(module, input, output):
    """
    Hook function of register_forward_hook.

    Averages the module output over its channels, rescales the result to
    [0, 255], colours it with the JET colormap and writes it as a 512x512
    image to the path returned by get_image_name_for_hook.

    Returns nothing, so the hooked module's output is left as it is.

    Parameters:
    -----------
    module: module of neural network
    input: input of module (not used)
    output: output of module; must be a 4-D tensor, assumed to be laid out
        as (batch, channel, height, width) -- TODO confirm for every hooked
        module

    Raises:
    -------
    ValueError: if output does not have exactly 4 dimensions
    """
    if output.dim() != 4:
        raise ValueError(
            "hook_func expects a 4-D output, got %d dimensions" % output.dim())

    image_name = get_image_name_for_hook(module)

    # The data is only read, so detach() without clone() is sufficient;
    # float() allows reduced-precision outputs to be converted to NumPy.
    data = output.detach().float().cpu().numpy()

    # Visualise the first sample of the batch only: averaging over its
    # channel axis gives a single 2-D map. Indexing explicitly (instead of
    # squeeze()) keeps this correct for batch > 1 and for 1-channel outputs.
    channel_mean = data[0].mean(axis=0)

    # Resize while still in floating point. Casting to uint8 before the
    # normalisation would wrap or clip negative and large activations.
    feature = cv2.resize(channel_mean, (512, 512))

    # Min-max normalise to [0, 1]; the epsilon guards against a zero
    # denominator when the map is constant.
    lowest = np.amin(feature)
    highest = np.amax(feature)
    feature = (feature - lowest) / (highest - lowest + 1e-5)

    # [0, 1] -> [0, 255] as uint8, the input applyColorMap works on.
    feature = np.round(feature * 255).astype(np.uint8)

    # Colour the map as a heat map (COLORMAP_JET is the constant 2).
    feature = cv2.applyColorMap(feature, cv2.COLORMAP_JET)
    cv2.imwrite(image_name, feature)
INSTANCE_FOLDER = "VIS_results"


def get_image_name_for_hook(module):
    """
    Generate image filename for hook function.

    The name is ``<INSTANCE_FOLDER>/<base>_<index>.png`` where ``<base>`` is
    the text of ``str(module)`` before its first '(' and ``<index>`` is the
    smallest integer >= 1 for which no such path exists yet.
    INSTANCE_FOLDER is created if it is missing. The file itself is not
    created here.

    Parameters:
    -----------
    module: module of neural network

    Returns:
    --------
    Path of the first unused image file name for this module type.
    """
    os.makedirs(INSTANCE_FOLDER, exist_ok=True)
    base_name = str(module).split('(')[0]

    index = 1
    while True:
        image_name = os.path.join(
            INSTANCE_FOLDER, '%s_%d.png' % (base_name, index))
        if not os.path.exists(image_name):
            return image_name
        index += 1
在验证(推理)循环处嵌入如下代码,用于注册上面定义的钩子函数
with torch.no_grad():
    # Names, as yielded by model.named_modules(), of the modules whose
    # outputs are saved as heat maps. An alternative is to select modules
    # by type, e.g. isinstance(module, (torch.nn.ReLU, torch.nn.Conv2d,
    # torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)).
    names_for_plot = ('module.classifier.fusion','module.classifier.context','module.classifier.context.2','module.classifier.context.2.aspp')

    # Keep the handles returned by register_forward_hook so the hooks can be
    # removed afterwards; otherwise every further run of this code adds
    # another set of hooks and each forward pass saves duplicate images.
    hook_handles = [
        module.register_forward_hook(hook_func)
        for name, module in model.named_modules()
        if name in names_for_plot
    ]
    try:
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so tqdm can
        # read the loader's length and show the total.
        for i, (images, labels) in enumerate(tqdm(loader)):
            # Only the first 20 batches are processed.
            if i >= 20:
                break
            # NOTE(review): this snippet shows no forward pass; presumably
            # the model call on `images` belongs here -- confirm.
    finally:
        for handle in hook_handles:
            handle.remove()
部分参照:https://blog.csdn.net/bby1987/article/details/109590108