detach() 函数是 PyTorch 中 Tensor 对象的方法之一,其作用是创建一个新的 Tensor,该 Tensor 和原始的 Tensor 共享相同的存储空间,但是和计算图断开连接,不再参与梯度计算。换句话说,detach() 函数可以用来获取一个 Tensor 的副本,但是该副本不再与计算图相关联,因此它不会影响到反向传播过程。
通常情况下,detach() 函数在需要将某个 Tensor 作为新的输入传递给某个函数或者模型时很有用。例如,在使用已经训练好的模型进行推断时,可能希望得到模型输出的副本,而不希望这个输出对模型的参数进行梯度计算。这时就可以使用 detach() 函数。
下面是一个示例:
import torch
# 定义一个需要梯度计算的 Tensor
x = torch.tensor([2.0], requires_grad=True)
# 计算 y = x^2
y = x**2
# 使用 detach() 获取 y 的副本,但是不再与计算图相关联
y_detached = y.detach()
# 对 y_detached 进行操作,不会影响到原始的 Tensor y
z = y_detached + 1
# 对 z 进行反向传播
z.backward()
# 输出 x 的梯度,此时 x.grad = 4,因为 z = y_detached + 1,而 y_detached = x^2,所以 z 对 x 的梯度是 2
print(x.grad)
在上面的示例中,y_detached 是通过 detach() 函数得到的 y 的副本,它和 y 共享相同的存储空间,但是不再与计算图相关联。因此,对 y_detached 进行操作不会影响到原始的 Tensor y,同时,z 对 x 的梯度计算也不会影响到 y,所以最终输出的 x.grad 为 4。