-
pytorch tensor 中的
.clone()和.detach()在写代码时经常能见到通过
tensor.detach().clone()操作生成一个和原本 tensor 值相同的新 tensor为什么需要同时使用
.clone()和.detach(),接下来通过代码进行说明-
生成两个 tensor,并且求梯度
a = torch.tensor([1.0, 1.0], requires_grad=True) b = torch.tensor([2.0, 2.0], requires_grad=True) loss = a@b loss.backward() print(a, b) print(a.grad, b.grad)输出结果:
tensor([1., 1.], requires_grad=True) tensor([2., 2.], requires_grad=True)
tensor([2., 2.]) tensor([1., 1.])可以看到 a, b 的梯度分别为 [2., 2.],[1., 1.]
-
使用 a_=a.detch() 脱离计算图
在上面的代码中加上
<a_=a.detch()并且使用a_计算和 backward()
-
pytorch 中的 .detach() .clone()
PyTorch中.clone()与.detach()的深度解析:梯度传递与内存共享
最新推荐文章于 2024-12-14 22:34:37 发布

最低0.47元/天 解锁文章
2012

被折叠的 条评论
为什么被折叠?



