【pytorch】tensor的复制避坑;tensor.clone() & tensor.detach() 详解

问题来源:

// A code block
a = torch.zeros(2, 3)
for i in range(2):
for j in range(3):
b = a.data
b[i, j] = 1
print("b:", b)
print("a:", a)

我原以为 b = a.data 就是开辟一个新空间给 b变量,然后修改 b 的值时 a 的值不会因此变化,谁知道即
使用了.data属性,修改 b 的值后 a 的值依然会发生变化。
out:

b: tensor([[1., 0., 0.],
[0., 0., 0.]])
a: tensor([[1., 0., 0.],
[0., 0., 0.]])
b: tensor([[1., 1., 0.],
[0., 0., 0.]])
a: tensor([[1., 1., 0.],
[0., 0., 0.]])
b: tensor([[1., 1., 1.],
[0., 0., 0.]])
a: tensor([[1., 1., 1.],
[0., 0., 0.]])
b: tensor([[1., 1., 1.],
[1., 0., 0.]])
a: tensor([[1., 1., 1.],
[1., 0., 0.]])
b: tensor([[1., 1., 1.],
[1., 1., 0.]])
a: tensor([[1., 1., 1.],
[1., 1., 0.]])
b: tensor([[1., 1., 1.],
[1., 1., 1.]])
a: tensor([[1., 1., 1.],
[1., 1., 1.]])

后来上网查找,发现无论是用 b = a,还是 b = a.data,变量a和b的关系都没有断开,下面介绍两种函
数,请注意区分!

最保险的办法:

既脱离了计算图也避免了内存共享。

a.detach().clone()

1:tensor.clone()

clone()函数可以返回一个完全相同的tensor,新的tensor开辟新的内存,但是仍然留在计算图中,而且这
个新tensor不能再作为叶子节点参加梯度运算。
例如:

a = torch.tensor([1,2,3], requires_grad=True, dtype=torch.float32)
c = a.clone()
b = (a**2).sum()
b.backward()
print(a) # tensor([1., 2., 3.], requires_grad=True)
print(c) # tensor([1., 2., 3.], grad_fn=<CloneBackward>)
print(a.grad) # tensor([2., 4., 6.])
print(c.grad) # None
print(a.requires_grad) # True
print(c.requires_grad) # True
a = 1
print(c) #tensor([1., 2., 3.], grad_fn=<CloneBackward>) c的值并没有
改变

我们可以发现,clone()出来的新变量和原来的变量没有任何关系,但是如果原来变量a的
requires_grad=True,那么clone()出来的变量c的 requires_grad=True,但是两者梯度没有任何关系。

2:tensor.detach()

detach()函数可以返回一个完全相同的tensor,新的tensor开辟与旧的tensor共享内存,新的tensor会脱
离计算图,不会牵扯梯度计算。此外,一些原地操作(in-place, such as resize_ / resize_as_ / set_ /
transpose_) 在两者任意一个执行都会引发错误。
看例子:

a = torch.tensor([1,2,3], requires_grad=True, dtype=torch.float32)
c = a.detach()
b = (a**2).sum()
b.backward()
print(a) # tensor([1., 2., 3.], requires_grad=True)
print(c) # tensor([1., 2., 3.])
print(a.grad) # tensor([2., 4., 6.])
print(c.grad) # None
print(a.requires_grad) # True
print(c.requires_grad) # False
a = 1
print(c) #tensor([1., 2., 3.])

我们可以发现,detach()不会把a和c的联系切断,改变c的值a也会改变,只是把新变量的 requires_grad
变成了 False。总而言之,a和c还是共享同一块内存单元。
但是要注意:

a = torch.ones(2,2)
a.requires_grad = True
c = a.detach()
#a[1,2] = 0
print(c) #tensor([1., 2., 3.])
print(a)
b = (a*2).sum()
b.backward()
print(a.grad)
a = a - a.grad #注意这里的 = 号给 a 分配了新内存,所以 a = 会破坏原来的 a c 联系生成新内
存单元下的a
print(a)
print(c)

out:

tensor([[1., 1.],
[1., 1.]])
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
tensor([[2., 2.],
[2., 2.]])
tensor([[-1., -1.],
[-1., -1.]], grad_fn=<SubBackward0>)
tensor([[1., 1.],
[1., 1.]])

还有:

a = torch.ones(2,2)
a.requires_grad = True
c = a
a = a + 1 # 新变量 a 和 c 的联系断了(不在一个内存单元了)
print(a)
print(c)

out:

tensor([[2., 2.],
[2., 2.]], grad_fn=<AddBackward0>)
tensor([[1., 1.],
[1., 1.]], requires_grad=True)

参考:https://blog.csdn.net/winycg/article/details/100813519

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值