我认为所谓的.detcah()就是两步操作:
- copy() 将原来的变量复制一份
- variable---->tensor 将复制品变为Tensor
x = Variable(t.rand(1,2,3), requires_grad=True)
y = Variable(t.rand(1,2,3), requires_grad=True)
z = (x.detach()*y)
print('1', x)
print('2', y)
z = y**2
t.sum(z).backward()
print('3', y.grad.data)
print('4', x.grad)
z1 = (x*y)
t.sum(z1).backward()
print('5', x)
print('6', x.grad.data)
输出如下:
1 tensor([[[0.1315, 0.0035, 0.1518],
[0.3182, 0.0520, 0.9082]]], requires_grad=True)
2 tensor([[[0.3761, 0.5830, 0.3036],
[0.3458, 0.4926, 0.3941]]], requires_grad=True)
3 tensor([[[0.7522, 1.1660, 0.6071],
[0.6916, 0.9852, 0.7882]]])
4 None
5 tensor([[[0.1315, 0.0035, 0.1518],
[0.3182, 0.0520, 0.9082]]], requires_grad=True)
6 tensor([[[0.3761, 0.5830, 0.3036],
[0.3458, 0.4926, 0.3941]]])
- 可以看到detach()复制品之后的梯度还是正常的反向传播,但是之前的梯度因为detach()复制品本身是Tensor,自然就不可能传回去了,而且x.detach()和x都不是一个东西了,x.detcah()相当于新的叶节点。
- 既然是copy(), x.detcah()的操作对于x没有一点影响,x想怎么浪还是怎么浪,x.detcah()虽然是x的阉割版,x本身依旧坚挺,反向传播还是可以通过x的。