一、detach()
和 requires_grad_(False)
解释
detach()
和 requires_grad_(False)
函数是 PyTorch
中用于解除计算图连接和停止梯度计算的函数。
detach():
将当前 tensor 从计算图中分离出来,返回一个新的 tensor,新 tensor 的 requires_grad 属性会被设为 False。也就是说,调用 detach() 后,将无法再通过这个 tensor 得到梯度信息,即使后续计算的结果与该 tensor 相关。
requires_grad_(False):
在原地修改 tensor 的 requires_grad 属性为 False,这个 tensor 之前的梯度信息将会被清除。
这两个函数在实际使用中常常用于以下几个方面:
1、在使用 Tensor 进行运算时,可能会产生一些临时结果,这些结果并不需要用于梯度计算,因此需要将这些 tensor 分离出计算图,以减少计算开销和内存消耗。
2、当需要对某些 tensor 进行修改,但不希望这些修改对之前的梯度计算产生影响时,可以使用 detach() 函数,以避免不必要的计算。
3、在使用预训练模型进行微调时,需要冻结一部分参数,不参与梯度更新。此时可以将这些参数的 requires_grad 属性设为 False,避免对其进行不必要的梯度计算。
二、例子
import torch
# 定义需要进行梯度计算的 tensor
x = torch.randn(3, 3, requires_grad=True)
w = torch.randn(3, 3, requires_grad=True)
# 计算 x 和 w 的点积
y = torch.matmul(x, w)
# 对 y 进行一些操作,但不需要计算梯度
z = y.detach().sigmoid()
# 计算 z 的平均值,并反向传播梯度
loss = z.mean()
loss.backward()
# 冻结 w 的梯度,只更新 x 的梯度
w.requires_grad_(False)
x.grad.zero_()
x2 = torch.randn(3, 3, requires_grad=True)
y2 = torch.matmul(x2, w)
loss2 = y2.mean()
loss2.backward()
# 打印梯度信息
print("x 的梯度:", x.grad)
print("w 的梯度:", w.grad)
print("x2 的梯度:", x2.grad)
在这个示例中,我们首先定义了两个需要进行梯度计算的 tensor:x 和 w。然后计算了它们的点积 y,并对 y 进行了一个 sigmoid 操作,得到了 z。由于我们不希望对 z 进行梯度计算,所以使用 detach() 函数将其从计算图中分离出来。接着,我们计算了 z 的平均值 loss,并进行了反向传播。这样就得到了 x 和 w 的梯度。
然后,我们冻结了 w 的梯度,只更新了 x 的梯度。为了说明这一点,我们定义了另外一个 tensor x2,将 x2 和 w 做点积,得到 y2。然后计算了 y2 的平均值 loss2,并进行了反向传播。由于我们冻结了 w 的梯度,所以只有 x2 的梯度被更新,w 的梯度仍然是零。最后打印了三个 tensor 的梯度信息。