pytorch.detach分离函数

1.官方定义detach

官方解释
Returns a new Tensor, detached from the current graph. The result will never require gradient. 返回一个与当前图分离的新张量。结果永远不需要梯度。

2. 解释

举例说明,假设我们有一个函数 y = x * x ,我们想重新有一个变量 u = y.detach(),此时 u 和 y 同值,我们再定义另外一个函数 z = u * x
y = x 2 (1) y = x^2 \tag{1} y=x2(1)
u = y . d e t a c h ( ) (2) u=y.detach()\tag{2} u=y.detach()(2)

  • u的值 和 y 的值保持一致
    z = u ∗ x (3) z=u*x\tag{3} z=ux(3)
  • 当我们 ∂ z ∂ x \frac{\partial z}{\partial x} xz时,因为分离,所以我们可以将此时的 u 当做常数看待
    ∂ z ∂ x = u \frac{\partial z}{\partial x}=u xz=u
  • 所以此时的梯度是不能够通过 u 进行传播的

3. 代码

# -*- coding: utf-8 -*-
import torch

x = torch.arange(4.0, requires_grad=True)  # 定义 x 
y = x * x  # 定义 y
u = y.detach() # 定义 u,此时的 u的值等于 y,相当于新建一个副本 u 
z = u * x # 此时当我们在 z 对 x 求导的时候,u 被当做一个常量
z.sum().backward()  # 因为 pytorch 中我们是标量对向量求导,所以需要用到 z.sum()
print(f'x={x}')
print(f'y={y}')
print(f'z={z}')
print(f'x.grad={x.grad}')

4. 结果

x=tensor([0., 1., 2., 3.], requires_grad=True)
y=tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)
z=tensor([ 0.,  1.,  8., 27.], grad_fn=<MulBackward0>)
x.grad=tensor([0., 1., 4., 9.])  #x.grad = u=x**2

5. 小结

detach就是截断反向传播的梯度流,使得变量没有了梯度反向传播。当运算时候 u 看作常数。
注:

  • 自变量 x 中的参数需要为 requires_grad=True
  • 自变量 x 中的值需要为浮点型的值
  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值