PyTorch中可微张量的in-place operation问题解决方法
关于可微张量的in-place operation(对原对像修改操作)的相关讨论。
- (1) 叶节点数值修改存在可微张量的in-place operation问题,会导致系统无法区分叶节点和其他节点的问题。
# in-place operation问题报错
# 但如果在计算过程中,我们使用in-place operation,让新生成的值替换w原始值,则会报错
w = torch.tensor(2., requires_grad = True)
w -= w * 2
'''
RuntimeError Traceback (most recent call last)
<ipython-input-11-285e4fe2cdf5> in <module>
1 w = torch.tensor(2., requires_grad = True)
----> 2 w -= w * 2
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation."
'''
从报错信息中可知,PyTorch中不允许叶节点使用in-place operation,根本原因是会造成叶节点和其他节点类型混乱。
- 修改w值
不过,虽然可微张量不允许in-place operation,但却可以通过其他方法进行对w进行修改。
w = torch.tensor(2., requires_grad = True)
w.is_leaf # True,w是叶节点
w = w * 2 # tensor(4., grad_fn=<MulBackward0>)
w.is_leaf # False,w不是叶节点
# 无法通过反向传播求其导数
w.backward()
w.grad
'''
<ipython-input-7-8623099507de>:9:
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed.
Its .grad attribute won't be populated during autograd.backward().
If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor.
If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead.
'''
但是该方法会导致叶节点丢失,无法反向传播求导。而在一张计算图中,缺少了对叶节点反向传播求导数的相关运算,计算图也就失去了核心价值。因此在实际操作过程中,应该尽量避免导致叶节点丢失的相关操作。
(2) 叶节点数值修改方法
- 使用
with torch.no_grad()
语句或者torch.detach()
方法
当然,如果出现了一定要修改叶节点的取值的情况,典型的如梯度下降过程中利用梯度值修改参数值时,可以使用此前介绍的暂停追踪的方法,如使用
with torch.no_grad()
语句或者torch.detach()
方法,使得修改叶节点数值时暂停追踪,然后再生成新的叶节点带入计算,如:
w = torch.tensor(2., requires_grad = True)
# 利用with torch.no_grad()暂停追踪
with torch.no_grad():
w -= w * 2 # tensor(-2., requires_grad=True)
w.is_leaf # True,w是叶节点
w = torch.tensor(2., requires_grad = True)
# 利用detach生成新变量
w.detach_() # tensor(2.)
w -= w * 2 # tensor(-2.)
w.requires_grad = True # tensor(-2., requires_grad=True)
w.is_leaf # True,w是叶节点
- 使用
.data
方法
当然,此处我们介绍另一种方法,
.data
来返回可微张量的取值,从在避免在修改的过程中被追踪
w = torch.tensor(2., requires_grad = True)
w.data # tensor(2.),查看张量的数值
w # tensor(2., requires_grad=True),但不改变张量本身可微性
# .data方法,对其数值进行修改
w.data -= w * 2
w # tensor(-2., requires_grad=True)
w.is_leaf # True,w仍然是叶节点