深度之眼 PyTorch 训练营第 4 期(13):hook 函数

15 篇文章 1 订阅
15 篇文章 3 订阅

我们知道,在计算图运行的过程中,中间变量的数据是不会被保留的。想要保存中间变量,一个方法是使用 tnesor.retain_grad = True,但是这样会使数据被永久保存,造成存储空间的消耗,即我们可能只需要用到中间变量的值一次,但它却被永久保存了。为了解决这个问题,PyTorch 引入了 hook 函数。hook 函数分为张量的 hook 函数和神经网络层的 hook 函数两种。

1. 张量的 hook 函数

  • Tensor.register_hook(hook)

这个 hook 函数在每次计算反向传播的时候会被调用。tensor 的 hook 函数有个性质:

hook(grad) -> Tensor or None

即如果 hook 函数不返回任何值,张量的导数不会改变;如果有返回值,张量原本的导数会被返回值覆盖。张量的 hook 函数的使用方法如下:

  1. 定义计算图;
  2. 定义 hook 函数;
  3. 注册 hook 函数;
  4. 反向传播;
  5. 移除 hook 函数。

我们还是使用 PyTorch 折桂 4:torch.autograph 中计算导数的例子:

w = torch.tensor([1.], requires_grad=True) # 定义计算图
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad): # 定义 hook 函数
    a_grad.append(grad)
    return grad * 3

handle = w.register_hook(grad_hook) # 注册 hook 函数

y.backward() # 反向传播

# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
print("w.grad: ", w.grad)
handle.remove() # 移除 hook 函数

运行后 PyTorch 先抛出来一个警告,说试图获得非叶子节点的导数。然后结果如下:

gradient: tensor([15.]) tensor([2.]) None None None
a_grad[0]:  tensor([5.])
w.grad:  tensor([15.])

可以看到,本来 a 的导数已经被释放掉了,但是我们通过 hook 函数把 a 的导数保存了起来。由于有返回值,导致 w 的导数被覆盖了。

2. 神经网络的 hook 函数

神经网络的 hook 函数分为前向传播的 hook 函数和反向传播的 hook 函数,分别为 register_forward_hook(hook)register_backward_hook(hook),使用方法与张量的 hook 函数大同小异。我们还是来看一个例子:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc(x)
        return x


def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))


# 初始化网络
net = Net()
net.fc.weight.detach().fill_(2)
net.fc.bias.data.detach().zero_()

# 注册hook
fmap_block = list()
input_block = list()
net.fc.register_forward_hook(forward_hook)
net.fc.register_backward_hook(backward_hook)

# inference
x = torch.ones(5)
output = net(x)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

这里我们定义了一个单层全连接网络,全连接网络的权重为 2,使用 l1 损失函数,获得的全连接层的反向传播导数为:

backward hook input:(tensor([1.]), tensor([1.]))
backward hook output:(tensor([1.]),)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值