『PyTorch』tensor.detach()和tensor.data的区别以及就地操作


tensor.detach()和tensor.data的区别


结论先行

推荐使用.detach()实现数据的分离


相同点:

a.data, a.detach() 为例:

两种方法均会返回和a相同的tensor,且与原tensor a 共享数据,一方改变,则另一方也改变。

所起的作用均是将变量tensor从原有的计算图中分离出来,分离所得tensor的requires_grad = False

这么做通常是为了切断某一些前向传播分支的梯度反向传播


不同点:

  • .data是一个属性,.detach()是一个方法;
  • .data是不安全的,.detach()是安全的;
import torch
 
a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()

# 通过.data分离后与原变量共享数据
# 分离所得变量requires_grad = False
data_c = out.data

# 使用就地操作
# 一方改变 则另一方改变
# out改变 梯度计算出错
data_c.zero_()

# 不使用就地操作 不会报错 data_c未改变 梯度计算正常
# d = data_c.add(torch.ones_like(data_c))

print(data_c.requires_grad)
print(data_c)
print(out.requires_grad)
print(out)

# 对输出结果out求导
out.sum().backward()
# 不会报错 但是结果不正确 因此.data不安全
print("梯度:")
print(a.grad)  
print("----------------------------------------------")

为什么.data是不安全的?

这是因为,当我们修改分离后的tensor,从而导致原tensora发生改变。PyTorch的自动求导Autograd是无法捕捉到这种变化的,会依然按照求导规则进行求导,导致计算出错误的导数值。

其风险性在于,如果我在某一处修改了某一个变量,求导的时候也无法得知这一修改,可能会在不知情的情况下计算出错误的导数值。


import torch

a = torch.tensor([1,2,3.], requires_grad = True)
out = a.sigmoid()

# 通过.detach分离后与原变量共享数据
# 分离所得变量requires_grad = False
detach_c = out.detach()

# 使用就地操作
# 一方改变 则另一方改变
# out改变 导致报错
detach_c.zero_()

# 不使用就地操作 不会报错 detach_c 梯度计算正常
# d = detach_c.add(torch.ones_like(detach_c))

print(detach_c.requires_grad)
print(detach_c)
print(out.requires_grad)
print(out)

# 对输出结果out求导
out.sum().backward()
# 报错 提醒你已有变量由于inplace operation发生了修改 因此.detach安全
print("梯度:")
print(a.grad)
print("----------------------------------------------")

那么.detach()为什么是安全的?

使用.detach()的好处在于,若是出现上述情况,Autograd可以检测出某一处变量已经发生了改变,进而以如下形式报错,从而避免了错误的求导。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

从以上可以看出,是在前向传播的过程中使用就地操作(In-place operation)导致了这一问题,那么就地操作是什么呢?


就地操作(In-place operation)

就地操作(In-place operation),在PyTorch中是指改变一个tensor变量的值时,不需要开创新的内存空间以进行复制操作,而是直接在原变量地址上进行修改。

其好处在于,可以节省内存空间

在PyTorch中以方法名接后缀“_”的方式来代表就地操作,比如.normal_() .add_() .scatter_()。值得一提的是,Python中的x += res也属于就地操作,而x = x + res则不属于。

然而,在PyTorch的Autograd机制下使用就地操作是一件棘手的事情,在大多数情况下不推荐使用就地操作。

Autograd积极的缓冲区释放和重用使其非常高效,很少会有就地操作显著降低内存使用的情况。除非是在内存压力非常大的情况下操作,否则可能永远都不需要使用它们。

那么,就地操作会带来哪些弊端呢?

  • 可能会覆盖计算梯度所需的值,这意味着破坏了模型的训练过程。
  • 每个就地操作都需要实现重写计算图。异地操作Out-of-place只是简单地分配新对象并保留对旧计算图的引用,而就地操作则需要将所有输入的创建者更改为表示此操作的Function。当修改后的输入也被其他张量引用时(比如共享数据的情况),就会引发错误。

所以避免这些问题的最好方法是,尽量不使用就地操作


另外,根据报错所给提示,可通过:

with torch.autograd.set_detect_anomaly(True):
        ......

检测就地操作报错的具体位置,详情可参考Github上的相关issue


参考

pytorch中的.detach和.data深入详解_LoveMIss-Y的博客-CSDN博客_pytorch的data属性

Autograd mechanics - PyTorch 1.13 documentation

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值