如何选取网络中间的输出?并极简方法可视化

在这里插入图片描述

一、提供几种方法

1. 直接修改模型的 forward 方法:

你可以通过修改模型的 forward 方法,直接返回第 layer5 层的输出。例如:

class ModifiedModel(nn.Module):
    def __init__(self, original_model):
        super(ModifiedModel, self).__init__()
        self.features = nn.Sequential(*list(original_model.children())[:6])  # 假设 layer5 是第六个子模块

    def forward(self, x):
        x = self.features(x)  # 这里 x 就是 layer5 的输出
        return x

这种方法的缺点是修改了模型的结构,并且可能不太灵活。

2. 使用 hook 函数:

hook 是一种更为灵活的方法,可以在模型执行时捕获中间层的输出。

outputs = []

def hook_fn(module, input, output):
    outputs.append(output)

layer5 = model.layer5  # 假设 layer5 是模型中的一个属性
handle = layer5.register_forward_hook(hook_fn)
# Forward pass
_ = model(input_tensor)
# Layer5 的输出现在存储在 outputs 中
layer5_output = outputs[0]
# 完成后注销 hook
handle.remove()

3. 直接访问模型的子模块:

如果你知道 layer5 的具体位置,可以直接在 forward 函数中通过 model.layer5(input) 的方式调用,获取输出。例如:

layer5_output = model.layer5(input)

这种方式假设你已经知道 layer5 是模型的一个属性。

4. 使用 nn.Sequential 来提取中间层:

这种方法是通过 nn.Sequential 创建一个包含所需层的子模型,从而提取中间层输出。

sub_model = nn.Sequential(*list(model.children())[:6])  # 获取到 layer5 的输出
layer5_output = sub_model(input_tensor)

2. 如何选择方法

  • 方法1 适合需要频繁获取特定层输出的情况,但它会修改模型结构。
  • 方法2 是最灵活的,不需要修改模型结构,非常适合调试或分析模型行为。
  • 方法3 适合在 forward pass 中直接获取某一层的输出,前提是你知道该层的具体位置。
  • 方法4 适合想要提取多层输出的情况。

3. 实际操作——提取特征并可视化

1. 可视化

可视化部分参考:https://blog.csdn.net/weixin_41496173/article/details/135713973

  • 首先是可视化代码
    3个入参分别是:提取的特征,前处理后的图,原图
def show_image_relevance(image_relevance, image, orig_image):
    # 这个函数的作用是将模型的注意力权重与原始图像结合,生成一个可视化的图像,
    # 显示模型在图像上的关注区域。参数包括 image_relevance(注意力权重),
    # image(模型输入图像),和 orig_image(原始图像)。
    def show_cam_on_image(img, mask):
        """
        用于将热力图叠加在图像上。img 是输入图像,mask 是注意力权重的热力图。
        Args:
            img:
            mask:

        Returns:

        """
        # 使用 OpenCV 的 applyColorMap 函数,将注意力权重(mask)转换为颜色热力图。
        # COLORMAP_JET 是一种常用的颜色映射方案,热力图的颜色表示不同强度的关注。
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255  # 将热力图的像素值归一化到 [0, 1] 范围,以便与输入图像结合。
        cam = heatmap + np.float32(img)  # 将热力图叠加到输入图像上,生成一个带有关注区域的图像。
        cam = cam / np.max(cam)  # 再次将图像归一化,以确保像素值在 [0, 1] 范围内,防止溢出。
        return cam

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(orig_image)
    axs[0].axis('off')

    dim = int(image_relevance.numel() ** 0.5)
    image_relevance = image_relevance.reshape(1, 1, dim, dim)
    image_relevance = torch.nn.functional.interpolate(image_relevance.float(), size=448, mode='bilinear')
    image_relevance = image_relevance.reshape(448, 448).cuda().data.cpu().numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = torch.Tensor(image).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    axs[1].imshow(vis)
    axs[1].axis('off')

2. 提取特征

  • 方式一
    知道是提取网络第几层的特征,且与网络前几层是连续的,则直接把网络的前几层拿出来推理就好。
    save_path = './XXX/layer_3'
    sub_model = nn.Sequential(*list(model.body.children())[:4])
    
    os.makedirs(save_path, exist_ok=True)
    layer_output = sub_model(tensor_batch.cuda()).cpu()
    for i in range(layer_output.shape[1]):
        show_image_relevance(layer_output[0, i, :, :], np_img, im)
        plt.savefig(os.path.join(save_path, f'show_pic_no{i}.jpg'))
        plt.close()  # 关闭当前图像窗口
    
  • 方式二
    提取的特征是head部分,通过前几层这样取不出,有个最简单的方法就是——1 修改一下return;
    在需要提取特征的地方直接加return即可。
  • 12
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Thomas_Cai

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值