需求:在生成对抗网络中,会有生成器和判别器分别训练的过程,为了避免一些不必要的计算,我们使用detach阻隔反向传播。
上图所示,完整训练一次的过程包括了判别器和生成器的训练
判别器的训练:
#判别器预测
pred_fake = net_d.forward(fake_ab.detach())
pred_real = net_d.forward(real_ab)
#判别器损失
loss_d_fake = criterionGAN(pred_fake, False)
loss_d_real = criterionGAN(pred_real, True)
loss_d = (loss_d_fake + loss_d_real) * 0.5
#判别器迭代优化
optimizer_d.zero_grad()
loss_d.backward()
optimizer_d.step()
注意,
1)未使用fake_ab.detach()
loss_d.backward()是对所有的变量,包括生成器,判别器变量,都计算梯度
2)使用fake_ab.detach()
loss_d.backward()是对所有的判别器变量,计算梯度,因为detach()隔断了梯度计算
3)optimizer_d.step()
optimizer_d仅仅对判别器变量进行梯度更新
结论:
1)detach()的目的仅仅是减小了计算梯度的变量个数,加速了训练过程。
2)生成器的训练就没办法使用detach(),只能将判别器的参数梯度也计算一遍,显然没必要。不过使用了optim进行了特定参数更新,梯度下降。