当我们再训练网络的时候可能
- 希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;
- 或者只训练部分分支网络,并不让其梯度对主网络的梯度造成影响,
这时候我们就需要使用detach()函数来切断一些分支的反向传播
1. detach()
返回一个新的Variable
,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个Variable
永远不需要计算其梯度,不具有grad。即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
例子:
此时的模型输入传递关系为:A-->B
反向传播避免某个模型A参数的更新:
# 得到A模型的输出a
a = A(input)
# 将模型A的输出a从当前计算图中分离下来的, 使grad为False
a = detach()
# 得到B模型的输出b
b = B(a)
# 由于A已经被从计算图中分离了,所以这里只更新B的参数
loss = criterion(b, target)
loss.backward()
反向传播避免某个模型B参数的更新:
# 将B的所有grad为False
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
# 只更新A的参数
loss = criterion(b, target)
loss.backward()
2. detach_()
将一个Variable
从创建它的图中分离,并把它设置成叶子variable
其实就相当于变量之间的关系本来是x -> m -> y,这里的叶子variable是x,但是这个时候对m进行了.detach_()操作,其实就是进行了两个操作:
- 将m的grad_fn的值设置为None,这样m就不会再与前一个节点x关联,这里的关系就会变成x, m -> y,此时的m就变成了叶子结点
- 然后会将m的requires_grad设置为False,这样对y进行backward()时就不会求m的梯度
⚠️这么一看其实detach()和detach_()很像,两个的区别就是detach_()是对本身的更改,detach()则是生成了一个新的variable
比如x -> m -> y中如果对m进行detach(),后面如果反悔想还是对原来的计算图进行操作还是可以的,但是如果是进行了detach_(),那么原来的计算图也发生了变化,就不能反悔了
Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法_torch detach-CSDN博客