Pytorch clone() detach()

1.clone()主要用于模块复用    数据进行复制,不共享同一内存,梯度可以回溯

c=torch.tensor(1.0,requires_grad=True)
b=c*2
d=b**2  (**)

b_=b.clone()
e_=b_**3
e_.backward(retain_graph=True)
"""
b.zero_()  这里的b是d.backward()的回溯节点(**),在回溯前不能进行in place 操作,
目的保证梯度计算正确,但如果是b_.zero_()就不会报错,因为clone不共享内存
"""
d.backward()
print(c.grad)  #tensor(32.)

 这里单独查看b_.grad或者b.grad都不存在,因为他们是中间变量,不需要保存,更新也是只更新叶子节点,此外要设置retain_graph=True,因为有一条线路上先进行了梯度回溯,为节省显存计算图会释放。

2.detach()主要用于数据的提取,共享同一内存,强制require_grad=False(即使设置为True也不进行梯度回溯)

c=torch.tensor(1.0,requires_grad=True)
b=c*2
w=b**2

b_=b.detach()
q=torch.tensor(1.0,requires_grad=True)
e_=q**b_
e_.backward()

#b_.zero_()  因为detach共享内存,这里进行in palce操作会报错
w.backward()
print(q.grad)  #tensor(2.)

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值