.detach() . + detach_() 解析

这篇博客介绍了PyTorch中detach()和detach_()函数的用途,它们用于在神经网络训练时阻止部分参数的更新。detach()返回一个新的Variable,不参与反向传播,而detach_()直接在原Variable上操作,切断其在计算图中的联系,两者都会设置requires_grad为False。理解这两个函数对于精细化网络训练和控制梯度传播至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

当我们再训练网络的时候可能

  1. 希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;
  2. 或者只训练部分分支网络,并不让其梯度对主网络的梯度造成影响

这时候我们就需要使用detach()函数来切断一些分支的反向传播

1. detach()

返回一个新的Variable,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable永远不需要计算其梯度,不具有grad即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

例子

此时的模型输入传递关系为:A-->B

反向传播避免某个模型A参数的更新:

# 得到A模型的输出a
a = A(input)
# 将模型A的输出a从当前计算图中分离下来的, 使grad为False
a = detach()

# 得到B模型的输出b
b = B(a)

# 由于A已经被从计算图中分离了,所以这里只更新B的参数
loss = criterion(b, target)
loss.backward()

反向传播避免某个模型B参数的更新:

# 将B的所有grad为False
for param in B.parameters():
	param.requires_grad = False

a = A(input)
b = B(a)
# 只更新A的参数
loss = criterion(b, target)
loss.backward()

2.  detach_()

将一个Variable从创建它的图中分离,并把它设置成叶子variable

其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子variable是x,但是这个时候对m进行了.detach_()操作,其实就是进行了两个操作:

  1. 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
  2. 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度

⚠️这么一看其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的variable

比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的,但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了

Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法_torch detach-CSDN博客

pytorch .detach() .detach_() 和 .data用于切断反向传播 - 慢行厚积 - 博客园

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Pengsen Ma

太谢谢了

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值