什么是inplace操作?
首先解释一下什么是inplace操作,在PyTorch中,我们直接在Tensor存储数据的内存上修改这个Tensor,而不是经过创建内存修改Tensor,我们称这种操作为inplace操作。
例子:
import torch
X = torch.tensor([1.0])
#不是inplace操作
X = X + 2
#是inplace操作
X.add_(2)
由于Tensor需要做自动微分,所以PyTorch会限制用户使用inplace操作的范围,当这个Tensor为叶节点并且requires_grad == True时,无法执行inplace操作。
主要原因在于,在我们求梯度之前,假如修改了某些参数的值会导致梯度计算出错。但是这个限制并不大,因为我们还可以使用tensor.data来更改tensor,因此tensor.data也被认为是不安全的。
import torch
X = torch.tensor([1.0],requires_grad = True)
X.add_(2)
"""
输出:
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
"""
但是并不是说一个需要梯度的叶子节点就无法执行inplace操作,只要处于with torch.no_grad():或者@no_grad之下,也是可以inplace的。optimizer更新参数时采用的就是这种方法。
import torch
X = torch.tensor([1.0],requires_grad = True)
with torch.no_grad():
print(X.requires_grad)
print(X.is_leaf)
X.add_(2)
"""
输出结果:
True
True
"""
但是要注意此时的结点不可以存在未完成反向传播的计算关系,否则也会报错。
#dZ/dX = 2X
X = torch.tensor([1.0],requires_grad = True)
Z = X*X
with torch.no_grad():
X.add_(2)
Z.backward()
print(X.grad)
"""
RuntimeError: one of the variables needed for gradient computation has been modified by an
inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead.
Hint: enable anomaly detection to find the operation that failed to compute its gradient,
with torch.autograd.set_detect_anomaly(True).
"""
不过用tensor.data进行inplace操作则不会报错,可以看到此时的梯度计算是错误的。
import torch
#dZ/dX = 2X
X = torch.tensor([1.0],requires_grad = True)
Z = X*X
X.data +=2
Z.backward()
print(X.grad)
"""
tensor([6.])
"""
总结
对tensor进行inplace操作主要有两种方法。
一种是利用add_,mul_等PyTorch自带的inplace计算函数,这种方法的使用是安全的,因为PyTorch内部会进行检查,可以配合torch.no_grad()一起使用。
另一种是利用tensor.data,这种方法是不安全的,因为PyTorch不会此进行检查,可能会导致反传出错。