pytorch_trick(3): PyTorch中可微张量的in-place operation问题解决方法

PyTorch中可微张量的in-place operation问题解决方法

关于可微张量的in-place operation(对原对像修改操作)的相关讨论。

  • (1) 叶节点数值修改存在可微张量的in-place operation问题,会导致系统无法区分叶节点和其他节点的问题。
# in-place operation问题报错
# 但如果在计算过程中,我们使用in-place operation,让新生成的值替换w原始值,则会报错

w = torch.tensor(2., requires_grad = True)
w -= w * 2  

'''
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-285e4fe2cdf5> in <module>
      1 w = torch.tensor(2., requires_grad = True)
----> 2 w -= w * 2

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation."
'''

从报错信息中可知,PyTorch中不允许叶节点使用in-place operation,根本原因是会造成叶节点和其他节点类型混乱。

  • 修改w值
    不过,虽然可微张量不允许in-place operation,但却可以通过其他方法进行对w进行修改。
w = torch.tensor(2., requires_grad = True)
w.is_leaf           # True,w是叶节点

w = w * 2           # tensor(4., grad_fn=<MulBackward0>)
w.is_leaf           # False,w不是叶节点

# 无法通过反向传播求其导数
w.backward()       
w.grad

'''
<ipython-input-7-8623099507de>:9: 
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. 
Its .grad attribute won't be populated during autograd.backward(). 
If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. 
If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. 
'''

但是该方法会导致叶节点丢失,无法反向传播求导。而在一张计算图中,缺少了对叶节点反向传播求导数的相关运算,计算图也就失去了核心价值。因此在实际操作过程中,应该尽量避免导致叶节点丢失的相关操作。

(2) 叶节点数值修改方法

  • 使用with torch.no_grad()语句或者torch.detach()方法

  当然,如果出现了一定要修改叶节点的取值的情况,典型的如梯度下降过程中利用梯度值修改参数值时,可以使用此前介绍的暂停追踪的方法,如使用with torch.no_grad()语句或者torch.detach()方法,使得修改叶节点数值时暂停追踪,然后再生成新的叶节点带入计算,如:

w = torch.tensor(2., requires_grad = True)     

# 利用with torch.no_grad()暂停追踪
with torch.no_grad():       
    w -= w * 2              # tensor(-2., requires_grad=True)

w.is_leaf                   # True,w是叶节点
w = torch.tensor(2., requires_grad = True)

# 利用detach生成新变量
w.detach_()             # tensor(2.)
w -= w * 2              # tensor(-2.)
w.requires_grad = True  # tensor(-2., requires_grad=True)

w.is_leaf               # True,w是叶节点
  • 使用.data方法

当然,此处我们介绍另一种方法,.data来返回可微张量的取值,从在避免在修改的过程中被追踪

w = torch.tensor(2., requires_grad = True)
w.data              # tensor(2.),查看张量的数值
w                   # tensor(2., requires_grad=True),但不改变张量本身可微性

# .data方法,对其数值进行修改
w.data -= w * 2     
w                   # tensor(-2., requires_grad=True)

w.is_leaf           # True,w仍然是叶节点
  • 15
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

白白白飘

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值