前言
pytorch
中的原地操作有时候很容易造成一些错误使用的情况,造成非期望的结果而且不容易调试,本文进行一些小记录。 如有谬误请联系指出,本文遵守 CC 4.0 BY-SA 版权协议,转载请联系作者并注明出处,谢谢。
∇ \nabla ∇ 联系方式:
e-mail: FesianXu@gmail.com
github: https://github.com/FesianXu
知乎专栏: 计算机视觉/计算机图形理论与应用
微信公众号:
在pytorch
中存在有很多原地(inplace)操作,比如torch.sigmoid_
和F.relu()
等。在计算&赋值这一个过程中,比如:
x = x+1
y = func(y)
我们对x
和y
进行计算,然后将其重新赋值给了x
和y
,在非原地操作中,我们对旧内存进行计算,然后产生新内存,然后再更新引用,如Fig 1所示。而原地操作中,我们在旧内存上直接更改数值,如Fig 2所示。这种原地操作更加节省内存,但是如果该内存可能被其他变量引用,可能导致计算一致性的问题,存在后效性。考虑到pytorch
中的F.relu
函数或者nn.ReLU(inplace=True)
层,再使用原地操作前,我们要确定其是贯序(Sequential)结构,而不会存在被其他变量引用的情况。使用错误的例子如:
def __init__(self):
self.conv1 = nn.Conv2d(...)
self.conv2 = nn.Conv2d(...)
self.relu = nn.ReLU(inplace=True)
...
def forward(self, x):
x = self.conv1(x)
h = self.relu(x)
h0 = self.conv2(x)
# unexpected error here
此时经过self.relu(x)
之后,x
的内存内容已经被更改了,如果此时去计算self.conv2(x)
,得到的结果很可能不是预期的,此时需要特别注意。这种情况在多分支(Multi-branch)的网络中很常出现。