Pytorch中如何获取某层Module?(方便改变梯度,获取特征图,CAM等)

本文介绍了如何在PyTorch模型中定位复杂的层结构,包括通过层名称从结构上查找和直接通过名字获取的方法,以进行梯度调整或可视化操作。以resnet34为例,详细演示了如何通过'layer1.1.conv2'这样的名字找到对应的层。
摘要由CSDN通过智能技术生成

1、引入

我们有时候需要改变某层梯度,或者获取某层梯度图,甚至画CAM可视化效果时,需要定位到模型中某层,但往往模型中层的结构层层嵌套,难。

2、方法

我们以torchvision.models中resnet34为例,来获取其中名"layer1.1.conv2"的层,以下是resnet34结构部分截图:在这里插入图片描述
首先我们先看看所有层的名字:

for n,v in net.named_parameters():
    print(n)

Output:
	conv1.weight
	bn1.weight
	bn1.bias
	layer1.0.conv1.weight
	layer1.0.bn1.weight
	layer1.0.bn1.bias
	layer1.0.conv2.weight
	layer1.0.bn2.weight
	layer1.0.bn2.bias
	layer1.1.conv1.weight
	layer1.1.bn1.weight
	layer1.1.bn1.bias
	layer1.1.conv2.weight
	layer1.1.bn2.weight
	layer1.1.bn2.bias
	layer1.2.conv1.weight

2.1、方法一:从结构上获取层

根据结果获取层的原则只有一个:非sequential对象的层,直接通过.名字获取,sequential对象的层通过[index]获取sequential对象里的某层
比如"layer1.1.conv2":

  • 1ayer1:1ayer1在一个非sequential对象中,所以直接通过model.layer1获取到该层
  • .1:layer1是一个sequential对象,获取它里面的组成需要通过[index]获取,所以通过model.layer1[1]获取到该层
  • .conv2:conv2是model.layer1[1]中的,而model.layer1[1]不是sequential对象,所以直接通过名字获取,所以通过model.layer1[1].conv2获取到
    所以如果你想对layer1.1.conv2做操作,直接通过lay=model.layer1[1].conv2就可以获取到该层,总结就是非纯数字的,直接通过名字,是纯数字的,通过上一层[index]来获取。

2.2、方式二:通过名字直接获取

我们已经知道想要操作的层的名字为"layer1.1.conv2",则通过以下方法获取该层:

name ="layer1.1.conv2"
for name, module in self.model.named_modules():             
    if name == self.module_name:
        #则module 就是该层
以下是基于pytorch写的hook提取模型特定特征图并可视化CAM的代码: ```python import torch import torch.nn as nn import cv2 import numpy as np class CAM(): def __init__(self, model, target_layer): self.model = model self.target_layer = target_layer self.features = [] self.grads = [] self.hook_fn = None def hook(self, module, input, output): self.features.append(output.cpu().data.numpy()) def hook_backward(self, module, grad_input, grad_output): self.grads.append(grad_output[0].cpu().data.numpy()) def get_cam(self, input_image, class_idx=None): self.features = [] self.grads = [] self.hook_fn = self.target_layer.register_forward_hook(self.hook) hook_fn_backward = self.target_layer.register_backward_hook(self.hook_backward) input_image = input_image.to(device) self.model.zero_grad() output = self.model(input_image) if class_idx is None: class_idx = np.argmax(output.cpu().data.numpy()) one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32) one_hot[0][class_idx] = 1 one_hot = torch.from_numpy(one_hot).requires_grad_(True) one_hot = torch.sum(one_hot * output) self.model.zero_grad() one_hot.backward() grads_val = self.grads[-1] target = self.features[-1] weights = np.mean(grads_val, axis=(2, 3))[0, :] cam = np.zeros(target.shape[2:], dtype=np.float32) for i, w in enumerate(weights): cam += w * target[0, i, :, :] cam = np.maximum(cam, 0) cam = cv2.resize(cam, input_image.shape[2:]) cam = cam - np.min(cam) cam = cam / np.max(cam) self.hook_fn.remove() hook_fn_backward.remove() return cam # 使用示例: class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.fc1 = nn.Linear(128 * 8 * 8, 256) self.fc2 = nn.Linear(256, 10) def forward(self, x): x = nn.functional.relu(self.conv1(x)) x = nn.functional.max_pool2d(x, 2) x = nn.functional.relu(self.conv2(x)) x = nn.functional.max_pool2d(x, 2) x = nn.functional.relu(self.conv3(x)) x = nn.functional.max_pool2d(x, 2) x = x.view(-1, 128 * 8 * 8) x = nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x model = Net() target_layer = model.conv3 cam = CAM(model, target_layer) # 加载图像 image_path = 'test.jpg' image = cv2.imread(image_path) image = cv2.resize(image, (32, 32)) image = np.transpose(image, (2, 0, 1)) image = image.astype(np.float32) / 255. image = torch.from_numpy(image) image = image.unsqueeze(0) # 生成CAM cam_map = cam.get_cam(image) # 可视化CAM heatmap = cv2.applyColorMap(np.uint8(255 * cam_map), cv2.COLORMAP_JET) heatmap = np.float32(heatmap) / 255 cam_image = heatmap + np.float32(image[0]) cam_image = cam_image / np.max(cam_image) cv2.imshow('CAM', cam_image) cv2.waitKey(0) ``` 在上面的代码中,我们定义了一个CAM类,用来提取模型特定特征图,并生成对应的CAM图像。CAM类中包含了一个hook函数,用来提取目标特征图,以及一个hook_backward函数,用来提取特征图对应的梯度。在get_cam函数中,我们首先将输入图像经过模型前向传播,然后根据输出结果确定目标类别。接着,我们通过反向传播计算目标类别对应的特征图梯度,并利用这个梯度生成CAM图像。最后,我们将CAM图像和原始图像叠加起来,生成可视化的CAM图像。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是一个对称矩阵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值