pytorch 中的 .detach() .clone()

PyTorch中.clone()与.detach()的深度解析:梯度传递与内存共享
  • pytorch tensor 中的 .clone().detach()

    detach() 的用法

    在写代码时经常能见到通过 tensor.detach().clone() 操作生成一个和原本 tensor 值相同的新 tensor

    为什么需要同时使用 .clone().detach() ,接下来通过代码进行说明

    1. 生成两个 tensor,并且求梯度

      a = torch.tensor([1.0, 1.0], requires_grad=True)
      b = torch.tensor([2.0, 2.0], requires_grad=True)
      loss = a@b
      loss.backward()
      print(a, b)
      print(a.grad, b.grad)
      

      输出结果:

      tensor([1., 1.], requires_grad=True) tensor([2., 2.], requires_grad=True)
      tensor([2., 2.]) tensor([1., 1.])

      可以看到 a, b 的梯度分别为 [2., 2.],[1., 1.]

    2. 使用 a_=a.detch() 脱离计算图

      在上面的代码中加上 a_=a.detch() 并且使用 a_ 计算和 backward()

      <
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值