pytorch中切断网络反向传递的方式最常见的方法就是使用detach来进行的,今天小编就带领小伙伴们来学习一下这种切断网络反传方式吧。
detach
官方文档中,对这个方法是这么介绍的。
detach = _add_docstr(_C._TensorBase.detach, r"""
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
.. note::
Returned Tensor uses the same data tensor as the original one.
In-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
""")
返回一个新的从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor
import torch
from torch.nn import init
t1 = torch.tensor([1., 2.],requires_grad=True)
t2 = torch.tensor([2., 3.],requires_grad=True)
v3 = t1 + t2
v3_detached = v3.detach()
v3_detach