特征图可视化:可解释的深度学习模型(Pytorch)

在这里插入图片描述
在这里插入图片描述

定义钩子函数

import torchvision.utils as vutil
import cv2
def hook_func(module, input, output):
    """
    Hook function of register_forward_hook

    Parameters:
    -----------
    module: module of neural network
    input: input of module
    output: output of module
    """
    image_name = get_image_name_for_hook(module)
    data = output.clone().detach()
    # data = data.permute(1, 0, 2, 3)
    # vutil.save_image(data, image_name, pad_value=0.5) # 这保存的是每个通道捕捉的语义

    data = data.permute(1,0,2,3).cpu().squeeze()
    pic = (np.mean(data.numpy(),axis=0)*255).astype(np.uint8)
    feature=cv2.resize(pic,(512,512))
    # 根据图像的像素值中最大最小值,将特征图的像素值归一化到了[0,1];
    feature = (feature - np.amin(feature))/(np.amax(feature) - np.amin(feature) + 1e-5) # 注意要防止分母为0! 
    feature = np.round(feature * 255) # [0, 1]——[0, 255],为cv2.imwrite()函数而进行
    feature = cv2.applyColorMap(np.array(feature,np.uint8),2) # 给特征图个颜色  热力图
    cv2.imwrite(image_name,feature)

INSTANCE_FOLDER = "VIS_results"
def get_image_name_for_hook(module):
    """
    Generate image filename for hook function

    Parameters:
    -----------
    module: module of neural network
    """
    os.makedirs(INSTANCE_FOLDER, exist_ok=True)
    base_name = str(module).split('(')[0]
    index = 0
    image_name = '.'  # '.' is surely exist, to make first loop condition True
    while os.path.exists(image_name):
        index += 1
        image_name = os.path.join(
            INSTANCE_FOLDER, '%s_%d.png' % (base_name, index))
    return image_name

在验证处嵌入如下定义

	with torch.no_grad():
        # modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
        #                 torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
        names_for_plot = ('module.classifier.fusion','module.classifier.context','module.classifier.context.2','module.classifier.context.2.aspp')
        for name, module in model.named_modules():
            # if isinstance(module, modules_for_plot):
            if name in names_for_plot:
                module.register_forward_hook(hook_func)

        for i, (images, labels) in tqdm(enumerate(loader)):
            if i>=20:
                break

部分参照:https://blog.csdn.net/bby1987/article/details/109590108

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

星空•物语

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值