Pytorch中反向传播计算图问题

Pytorch中反向传播计算图问题

问题复现:
  • pytorch中进行梯度计算的过程中,如果计算图已经完成了构建,那么即使变更了计算图中的数值结构,计算结果或出现的报错也不会改变
示例解析:

构架 x ∗ w x*w xw,并计算对应的mse loss:

# -*- coding: utf-8 -*-
import torch
import numpy as np

x = torch.ones((1, 3))
w = torch.full([3, 1], 2.)
mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)

计算mse对应 w w w梯度:

torch.autograd.grad(mse, [w])

这时会报错:

Traceback (most recent call last):
  File "C:\data\PyCharm 2021.1.2\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "C:\data\anaconda\envs\torch\lib\site-packages\torch\autograd\__init__.py", line 225, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

会提示 w w w权重在构建的时候require_grad为False,那我们将w的赋予梯度计算信息,再次计算梯度:

w.requires_grad_()
torch.autograd.grad(mse, [w])

这时候还是会报错:

tensor([[2.],
        [2.],
        [2.]], requires_grad=True)

Traceback (most recent call last):
  File "C:\data\PyCharm 2021.1.2\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "C:\data\anaconda\envs\torch\lib\site-packages\torch\autograd\__init__.py", line 225, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

还是会提示没有梯度信息,这就是由于在构建mse计算图时,计算图已经构建完成,这时候即便更改计算图中权重 w w w的梯度请求信息,计算图也不会更新,从而报错;这时我们需要重新构建计算图,才能进行正常的计算。

mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)
torch.autograd.grad(mse, [w])

输出结果:

(tensor([[0.6667],
        [0.6667],
        [0.6667]]),)
完整测试代码:
# -*- coding: utf-8 -*-
import torch
import numpy as np

x = torch.ones((1, 3))
w = torch.full([3, 1], 2.)
mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)

grad = torch.autograd.grad(mse, [w])
print(grad)

w.requires_grad_()
torch.autograd.grad(mse, [w])

mse = torch.nn.functional.mse_loss(torch.ones(1), x*w)
torch.autograd.grad(mse, [w])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值