1.官方定义detach
官方解释
Returns a new Tensor, detached from the current graph. The result will never require gradient. 返回一个与当前图分离的新张量。结果永远不需要梯度。
2. 解释
举例说明,假设我们有一个函数 y = x * x ,我们想重新有一个变量 u = y.detach(),此时 u 和 y 同值,我们再定义另外一个函数 z = u * x
y
=
x
2
(1)
y = x^2 \tag{1}
y=x2(1)
u
=
y
.
d
e
t
a
c
h
(
)
(2)
u=y.detach()\tag{2}
u=y.detach()(2)
- u的值 和 y 的值保持一致
z = u ∗ x (3) z=u*x\tag{3} z=u∗x(3) - 当我们
∂
z
∂
x
\frac{\partial z}{\partial x}
∂x∂z时,因为分离,所以我们可以将此时的 u 当做常数看待
∂ z ∂ x = u \frac{\partial z}{\partial x}=u ∂x∂z=u - 所以此时的梯度是不能够通过 u 进行传播的
3. 代码
# -*- coding: utf-8 -*-
import torch
x = torch.arange(4.0, requires_grad=True) # 定义 x
y = x * x # 定义 y
u = y.detach() # 定义 u,此时的 u的值等于 y,相当于新建一个副本 u
z = u * x # 此时当我们在 z 对 x 求导的时候,u 被当做一个常量
z.sum().backward() # 因为 pytorch 中我们是标量对向量求导,所以需要用到 z.sum()
print(f'x={x}')
print(f'y={y}')
print(f'z={z}')
print(f'x.grad={x.grad}')
4. 结果
x=tensor([0., 1., 2., 3.], requires_grad=True)
y=tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)
z=tensor([ 0., 1., 8., 27.], grad_fn=<MulBackward0>)
x.grad=tensor([0., 1., 4., 9.]) #x.grad = u=x**2
5. 小结
detach就是截断反向传播的梯度流,使得变量没有了梯度反向传播。当运算时候 u 看作常数。
注:
- 自变量 x 中的参数需要为 requires_grad=True
- 自变量 x 中的值需要为浮点型的值