import torch
t0 = torch.ones((10))
t1 = t0.detach()
print(id(t0),id(t1))
print(id(t0.data) , id(t1.data))
运行结果:
1879016743744 1879016743808
1879016743936 1879016743936
detach()
方法是重新建一个tensort1
,不过t1
和t0
的data
和grad
是共用的。
t0 = torch.ones((10))
t1 = t0.detach_()
print(id(t0),id(t1))
print(id(t0.data) , id(t1.data))
运行结果:
1879016743872 1879016743872
1879016743744 1879016743744
detach_()
方法是内置方法,返回的tensor地址是一样的。
共同点:
两种方法的共同点都有:将返回的tensor从计算图中剥离。
不同点:
不同点:detach()方法可以保留原tensor,用于后续的反向传播。
在生成对抗网络中使用方法:
在生成对抗网络中,我们使用detach()
,因为我们还要保留原tensor用于反向传播,所以不能用detach_()
。