如图是pytorch采用计算图来求解线性方程:
y
(
h
,
x
)
=
W
h
∗
h
+
W
x
∗
x
y(h,x)=W_{h}*h+W_{x} *x
y(h,x)=Wh∗h+Wx∗x其中‘→’的方向为反向传播的方向。然而一般情况下当反向传播backward()
结束时代表计算图的一次迭代就结束了,此时计算图会自动free掉。但在我们的实验过程中,常常需要设计复杂的损失函数以取得我们所需要的显著的实验效果。如下图:
两个损失函数是截然不同的两类损失函数,因此我们可以通过代码:backward(retain_graph=True)
在计算出第一个损失函数的梯度值后保存计算图用于继续计算第二个损失函数的梯度。
pytorch的计算图loss.backward(retain_graph=True)实现复杂loss的训练
最新推荐文章于 2024-05-28 11:30:31 发布