pytorch tensor: 获得中间节点的梯度

1. 输入要梯度,输出必须要梯度

我们只能指定计算图的leaf节点的requires_grad变量来决定改变量是否记录梯度,而不能指定它们运算产生的节点的requires_grad,它们是否要梯度取决于它们的输入节点,它们的输入节点只要有一个requires_grad是True,那么它的requires_grad也是True.

x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
2w.requires_grad = True
y = x @ w.t()
z = y @ w2.t()
print(y.requires_grad, z.requires_grad)

z.sum().backward()

2. 获得中间节点的梯度

对于叶节点,如果我们指定了梯度,我们可以调用v.grad查看梯度;但是对于中间变量v.grad永远是None,如果要获得其梯度,就要使用register_hook,它会在调用这个变量的梯度反传的时候调用注册的函数.以下是一个简单的查看版本

import torch
from torch import nn

def hook(grad):
	print(grad)

x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()

y.register_hook(hook)
z.sum().backward()  # invoke get_grad('y') here

改进版

import torch


class GradCollector(object):
    def __init__(self):
        self.grads = {}

    def __call__(self, name: str):
        def hook(grad):
            self.grads[name] = grad
        return hook
    

x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()

grad_collector = GradCollector()
y.register_hook(grad_collector("y"))
z.register_hook(grad_collector('z'))

z.sum().backward()

print(grad_collector.grads['y'])
print(grad_collector.grads['z'])
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值