ViT的可解释化-深度解析GradCam源码(2)-注册hook获取梯度

类ActivationsAndGradients用于从PyTorch模型的指定中间层提取激活值和梯度。它在目标层注册前向和反向钩子,以在前向传播时保存激活图,在反向传播时保存梯度。提供的reshape_transform函数可对数据进行变换。__call__方法执行模型前向传播并收集信息,release方法用于移除所有注册的钩子。
摘要由CSDN通过智能技术生成

今天来讲一下 class ActivationsAndGradients,他用于从指定中间层提取激活值和注册对应的梯度,是GradCam库很重要的一部分。

class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(
                    self.save_activation))
            # Backward compatibility with older pytorch versions:
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(
                        self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        # Gradients are computed in reverse order
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x):
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()

init 初始化

类的初始化函数 __init__ 接收四个参数:model 表示待处理的 PyTorch 模型,target_layers 表示指定提取激活值和梯度的层列表,reshape_transform 表示对激活值和梯度进行变换的函数或方法。

注册hook

通过循环遍历目标层列表 target_layers,对每一层分别调用 register_forward_hookregister_backward_hook 方法,在网络中注册前向钩子和反向钩子。

前向钩子用于记录目标层的输出特征图,反向钩子用于记录目标层的梯度,将保存激活值和注册梯度计算函数绑定到该层,然后将得到的 handle对象存入 handles 列表中。

两个hooks函数

save_activationsave_gradient 方法中,分别接收到注册的 forward hook 和 backward hook,并将其输出的激活值和梯度加入 activationsgradients 列表中。其中,如果存在 reshape_transform,则会对激活值和梯度进行指定的变换(比如 Flatten),最后将其从 GPU 转移到 CPU 并 detach 出来。

call方法

__call__ 方法重载了函数调用操作符 ()

首先清空 gradientsactivations 两个列表,

然后模型前向传播,

在模型前向传播的过程中激活init时注册过的hook,获取到gradientsactivations

release方法

release 方法则用于手动释放所有的 hooks,在不需要时对类进行释放。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值