pytorch中detach(), clone()函数和requires_grad_(False)解释

简单不看版:

 detach将当前 tensor 从计算图中分离出来,返回一个和源张量同shape、dtype和device的张量,与源张量共享数据内存(相当于还是指向原tensor),但不提供梯度计算,即requires_grad=False,因此脱离计算图。

clone:返回一个和源张量同shapedtypedevice的张量,与源张量不共享数据内存,但提供梯度的回溯。 

简单版:

>>> import torch
>>> a = torch.tensor(1.0, requires_grad=True)
>>> b = a.clone()
>>> print(a)
tensor(1., requires_grad=True)
>>> print(b)
tensor(1., grad_fn=<CloneBackward0>)
>>> b = a.clone().detach()
>>> print(b)
tensor(1.)

一、detach() 和 requires_grad_(False) 解释

detach() 和 requires_grad_(False) 函数是 PyTorch 中用于解除计算图连接和停止梯度计算的函数。

detach(): 将当前 tensor 从计算图中分离出来,返回一个新的 tensor,新 tensor 的 requires_grad 属性会被设为 False。也就是说,调用 detach() 后,将无法再通过这个 tensor 得到梯度信息,即使后续计算的结果与该 tensor 相关。
requires_grad_(False): 在原地修改 tensor 的 requires_grad 属性为 False,这个 tensor 之前的梯度信息将会被清除。
这两个函数在实际使用中常常用于以下几个方面:

1、在使用 Tensor 进行运算时,可能会产生一些临时结果,这些结果并不需要用于梯度计算,因此需要将这些 tensor 分离出计算图,以减少计算开销和内存消耗。

2、当需要对某些 tensor 进行修改,但不希望这些修改对之前的梯度计算产生影响时,可以使用 detach() 函数,以避免不必要的计算。

3、在使用预训练模型进行微调时,需要冻结一部分参数,不参与梯度更新。此时可以将这些参数的 requires_grad 属性设为 False,避免对其进行不必要的梯度计算。
二、例子

import torch

# 定义需要进行梯度计算的 tensor
x = torch.randn(3, 3, requires_grad=True)
w = torch.randn(3, 3, requires_grad=True)

# 计算 x 和 w 的点积
y = torch.matmul(x, w)

# 对 y 进行一些操作,但不需要计算梯度
z = y.detach().sigmoid()

# 计算 z 的平均值,并反向传播梯度
loss = z.mean()
loss.backward()

# 冻结 w 的梯度,只更新 x 的梯度
w.requires_grad_(False)
x.grad.zero_()
x2 = torch.randn(3, 3, requires_grad=True)
y2 = torch.matmul(x2, w)
loss2 = y2.mean()
loss2.backward()

# 打印梯度信息
print("x 的梯度:", x.grad)
print("w 的梯度:", w.grad)
print("x2 的梯度:", x2.grad)

在这个示例中,我们首先定义了两个需要进行梯度计算的 tensor:x 和 w。然后计算了它们的点积 y,并对 y 进行了一个 sigmoid 操作,得到了 z。由于我们不希望对 z 进行梯度计算,所以使用 detach() 函数将其从计算图中分离出来。接着,我们计算了 z 的平均值 loss,并进行了反向传播。这样就得到了 x 和 w 的梯度。

然后,我们冻结了 w 的梯度,只更新了 x 的梯度。为了说明这一点,我们定义了另外一个 tensor x2,将 x2 和 w 做点积,得到 y2。然后计算了 y2 的平均值 loss2,并进行了反向传播。由于我们冻结了 w 的梯度,所以只有 x2 的梯度被更新,w 的梯度仍然是零。最后打印了三个 tensor 的梯度信息。
                        
原文链接:https://blog.csdn.net/wen_ding/article/details/129558670

有一个总结很好的blog:【python基础】PyTorch中clone()、detach()_pytorch clone-CSDN博客

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值