pytorch register_hook 学习
register_hook 简介:
pytorch 的 hook 函数分为 torch.Tensor 和 torch.nn.Module 两类, 分别对应:
torch.Tensor.register_hook
torch.nn.Module.register_backward_hook
参考:
https://blog.csdn.net/qq_42110481/article/details/81043932
register_hook 的作用:
即对x求导时,对x的导数进行操作,并且register_hook的参数只能以函数的形式传过去。
给出一段来自于CAM的代码
class FeatureExtractor():
""" Class for extracting activations and
registering gradients from targetted intermediate layers """
def __init__(self, model, target_layers):
self.model = model
self.target_layers = target_layers
self.gradients = []
def save_gradient(self, grad):
self.gradients.append(grad)
def __call__(self, x):
outputs = []
self.gradients = []<