课程笔记:HOOK函数

hook函数可以分为两部分:关于tensor(第一种)和关于module(第二三四种)
在这里插入图片描述

tensor.register_hook

在反向传播完成时,非叶子结点的梯度会消失
tensor.register_hook作用:
1)完成保存非叶子结点的梯度
2)修改叶子结点的值
在这里插入图片描述
例如:保存a的梯度值;修改w的梯度值
在这里插入图片描述

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    a_grad.append(grad)
# 通过a.register_hook保存a_grad(中间结点的梯度)值
handle = a.register_hook(grad_hook)

y.backward()
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    grad *= 2
    return grad*3
# 通过a.register_hook修改w的梯度值
handle = w.register_hook(grad_hook)

y.backward()

Module.register_forward_hook

获取forward种的feature map(是在某层conv执行之后再执行register_forward_hook)
在这里插入图片描述

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)
# 注册hook
net.conv1.register_forward_hook(forward_hook)

Module.register_forward_pre_hook

在这里插入图片描述

def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))
# 注册hook
net.conv1.register_forward_pre_hook(forward_pre_hook)

Module.register_backward_hook

在这里插入图片描述

def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))
# 注册hook
net.conv1.register_backward_hook(backward_hook)

整个过程下,register_forward_hook、register_forward_pre_hook、register_backward_hook的运行过程
可见,在模型实例化后再初始化再注册hook

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))

def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# 注册hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

小结

HOOK函数的2.3.4种(即关于module部分)在module中的call函数中执行的。
call函数种的顺序是:
forward_pre_hook
forward
forward_hook
backward_hook
在这里插入图片描述
由此可见,module中的call函数并不是只有forward函数,而是借助hook函数实现其他的功能

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值