one=torch.tensor(1,dtype=torch.float32)|
mone=one*-1
moneg=one*-1*gan_loss_percent
D = model_discrination()
#为什么在辨别器反向传播时
d_real = D(real_pair)
d_real = d_real.mean()
d_real.backward(mone)
d_fake = D(fake_pair)
d_fake = d_fake.mean()
d_fake.backward(one)
#生成器反向传播时:
segloss.backward(retain_graph=True)
gd_fake = D(fake_pair)
gd_fake = gd_fake.mean()
gd_fake.backward(moneg)
在生成对抗网络(GAN)的训练过程中,生成器(G)和鉴别器(D)通过反向传播更新权重以达到最优化。训练分为两个主要步骤:训练鉴别器识别真实图像和生成图像,以及训练生成器欺骗鉴别器。代码片段中的反向传播策略是实现这一训练过程的关键部分。我们来详细解释每一步:
1.鉴别器的训练
鉴别器的目标是正确区分出真实图像和生成图像。在训练鉴别器时,会分别处理真实图像对(`real_pair`)和生成图像对(`fake_pair`)。对于每种情况,鉴别器给出一个分数(通过`D(real_pair)`或`D(fake_pair)`),表示图像为真实图像的可能性。
- **真实图像的梯度更新**:
d_