PyTorch中的hook

PyTorch中的hook

PyTorch中有一个重要的机制就是自动求导机制。

如果需要记下一些中间变量的结果,或者是人为对导数做一些改变的话,就需要使用hook。

三类hook:

(1) torch.tensor(3).register_hook,针对tensor

(2) torch.nn.Module.register_forward_hook,针对nn.Module

(3) torch.nn.Module.register_backward_hook,针对nn.Module

针对Tensor的hook

该函数在PyTorch中的实现如下:

    def register_hook(self, hook):
        r"""Registers a backward hook.

        The hook will be called every time a gradient with respect to the
        Tensor is computed. The hook should have the following signature::

            hook(grad) -> Tensor or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        Example::

            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
            >>> v.backward(torch.tensor([1., 2., 3.]))
            >>> v.grad

             2
             4
             6
            [torch.FloatTensor of size (3,)]

            >>> h.remove()  # removes the hook
        """
        if not self.requires_grad:
            raise RuntimeError("cannot register a hook on a tensor that "
                               "doesn't require gradient")
        if self._backward_hooks is None:
            self._backward_hooks = OrderedDict()
            if self.grad_fn is not None:
                self.grad_fn._register_hook_dict(self)
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

说的是什么意思呢?register_hook是在反传的时候用的,它的参数是一个函数,函数的形式为:

 hook(grad) -> Tensor or None

grad是这个tensor的梯度,返回是一个新的梯度值(可以代替原来的梯度值返回),但是它不应该改变原来的梯度值。同时这个hook也是可以移除的。它还会返回一个句柄handle,这个handle有一个remove()方法,可以用handle.remove()将这个hook移除,接下来举了一个例子说明register_hook怎么用。

在底层实现的时候,前面都是一些判断,重要的是注册了一个id,将这个hook和相应的tensor联系起来了。

针对nn.Module的hook

register_forward_hook(hook)

该函数在PyTorch中的实现如下:

    def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None or modified output

        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

上面是什么意思呢?它是在module上注册一个forward_hook,每次调用forward计算输出的时候,该函数就会被调用,这个hook的形式为:

hook(module, input, output) -> None or modified output

该函数不应该修改input和output的值,返回一个句柄(handle),它还有一个方法handle.remove(),可以用handle.remove()将这个hook移除。

在代码计算model(X)的时候,底层先调用forward函数完成前向的操作,然后判断是否存在register_forward_hook(hook),如果有的话,就调用相应的hook完成一定的功能。

register_backward_hook(hook)

首先看一下底层的代码:

    def register_backward_hook(self, hook):
        r"""Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        .. warning ::

            The current implementation will not have the presented behavior
            for complex :class:`Module` that perform many operations.
            In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
            contain the gradients for a subset of the inputs and outputs.
            For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
            directly on a specific input or output to get the required gradients.

        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

和前向的相同,都是在module上注册一个backward_hook,每次调用backward计算输出的时候,该函数就会被调用,这个hook的形式为:

hook(module, grad_input, grad_output) -> Tensor or None

参考

使用hook中的bug

pytorch中autograd以及hook函数详解

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorchhook机制是一种用于在计算图注册回调函数的机制。当计算图被执行时,这些回调函数会被调用,并且可以对计算图间结果进行操作或记录。 在PyTorch,每个张量都有一个grad_fn属性,该属性表示该张量是如何计算得到的。通过在这个grad_fn上注册一个hook函数,可以在计算图的每一步获取该张量的梯度,或者在该张量被使用时获取该张量的值。这些hook函数可以被用来实现一些调试、可视化或者改变计算图的操作。 下面是一个简单的例子,其我们在计算图的每一步都打印出间结果和梯度: ```python import torch def print_tensor_info(tensor): print('Tensor shape:', tensor.shape) print('Tensor value:', tensor) print('Tensor gradient:', tensor.grad) x = torch.randn(2, 2, requires_grad=True) y = x * 2 z = y.mean() # 注册一个hook函数,用来打印间结果和梯度 y.register_hook(print_tensor_info) # 执行计算图 z.backward() # 输出结果 print('x gradient:', x.grad) ``` 在这个例子,我们定义了一个张量x,并计算了y和z。我们在y上注册了一个hook函数,该函数在计算图的每一步都会被调用。然后我们执行了z的反向传播,计算出了x的梯度。最后,我们打印出了x的梯度。 需要注意的是,hook函数不能修改张量的值或梯度,否则会影响计算图的正确性。此外,hook函数只会在计算图的正向传播和反向传播时被调用,而不会在张量被直接使用时被调用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值