什么是PyTorch.detach()
方法?
PyTorch的分离方法适用于张量类。
tensor.detach()
创建一个与不需要梯度的张量共享存储的张量。tensor.clone()
创建一个仿照原张量的张量副本requires_grad
场。
你应该用detach()
当试图从计算图中删除张量时,克隆作为复制张量的一种方式,同时仍将复制作为来自计算图的一部分。
让我们在这里的一个例子中看到
X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
y = X**2
z = X**2
result = (y+z).sum()
torchviz.make_dot(result).render('Attached', format='png')
现在有了分离。
X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
y = X**2
z = X.detach()**2
result = (y+z).sum()
torchviz.make_dot(result).render('Attached', format='png')
正如您现在可以看到的那样,计算的分支x**2
不再被追踪。这反映在不再记录此分支的贡献的结果的梯度中。