今天来讲一下 class ActivationsAndGradients,他用于从指定中间层提取激活值和注册对应的梯度,是GradCam库很重要的一部分。
class ActivationsAndGradients:
""" Class for extracting activations and
registering gradients from targeted intermediate layers """
def __init__(self, model, target_layers, reshape_transform):
self.model = model
self.gradients = []
self.activations = []
self.reshape_transform = reshape_transform
self.handles = []
for target_layer in target_layers:
self.handles.append(
target_layer.register_forward_hook(
self.save_activation))
# Backward compatibility with older pytorch versions:
if hasattr(target_layer, 'register_full_backward_hook'):
self.handles.append(
target_layer.register_full_backward_hook(
self.save_gradient))
else:
self.handles.append(
target_layer.register_backward_hook(
self.save_gradient))
def save_activation(self, module, input, output):
activation = output
if self.reshape_transform is not None:
activation = self.reshape_transform(activation)
self.activations.append(activation.cpu().detach())
def save_gradient(self, module, grad_input, grad_output):
# Gradients are computed in reverse order
grad = grad_output[0]
if self.reshape_transform is not None:
grad = self.reshape_transform(grad)
self.gradients = [grad.cpu().detach()] + self.gradients
def __call__(self, x):
self.gradients = []
self.activations = []
return self.model(x)
def release(self):
for handle in self.handles:
handle.remove()
init 初始化
类的初始化函数 __init__
接收四个参数:model
表示待处理的 PyTorch 模型,target_layers
表示指定提取激活值和梯度的层列表,reshape_transform
表示对激活值和梯度进行变换的函数或方法。
注册hook
通过循环遍历目标层列表 target_layers
,对每一层分别调用 register_forward_hook
和 register_backward_hook
方法,在网络中注册前向钩子和反向钩子。
前向钩子用于记录目标层的输出特征图,反向钩子用于记录目标层的梯度,将保存激活值和注册梯度计算函数绑定到该层,然后将得到的 handle对象存入 handles
列表中。
两个
hooks函数
在 save_activation
和 save_gradient
方法中,分别接收到注册的 forward hook 和 backward hook,并将其输出的激活值和梯度加入 activations
和 gradients
列表中。其中,如果存在 reshape_transform
,则会对激活值和梯度进行指定的变换(比如 Flatten),最后将其从 GPU 转移到 CPU 并 detach 出来。
call方法
__call__
方法重载了函数调用操作符 ()
,
首先清空 gradients
和 activations
两个列表,
然后模型前向传播,
在模型前向传播的过程中激活init时注册过的hook,获取到gradients
和 activations
release方法
release
方法则用于手动释放所有的 hooks,在不需要时对类进行释放。