Below is a snippet of GAN training code:[^1]
```python
# process outputs
outputs = self(images, edges, masks)
gen_loss = 0
dis_loss = 0

# discriminator loss
dis_input_real = torch.cat((images, edges), dim=1)
dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
dis_real, dis_real_feat = self.discriminator(dis_input_real)  # in: (grayscale(1) + edge(1))
dis_fake, dis_fake_feat = self.discriminator(dis_input_fake)  # in: (grayscale(1) + edge(1))
dis_real_loss = self.adversarial_loss(dis_real, True, True)
dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
dis_loss += (dis_real_loss + dis_fake_loss) / 2

# generator adversarial loss
gen_input_fake = torch.cat((images, outputs), dim=1)
gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)  # in: (grayscale(1) + edge(1))
gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
gen_loss += gen_gan_loss
```
Why is the discriminator's input `torch.cat((images, outputs.detach()), dim=1)`, i.e. why is `outputs` detached here?

A few points worth noting (a minimal sketch illustrating them follows this list):
1. For `x = Variable(torch.Tensor([1]), requires_grad=True)`: without `requires_grad=True`, autograd does not compute gradients for that variable by default. (Since PyTorch 0.4, `Variable` is deprecated; you would write `torch.tensor([1.0], requires_grad=True)` instead.)
2. In the code above, `dis_input_real` has `requires_grad=False` (it is built from raw input tensors), but the result of `self.discriminator(dis_input_real)` has `requires_grad=True`, because the discriminator's own parameters require gradients (`self.discriminator` is the discriminator network; in `outputs = self(images, edges, masks)`, `self` is the generator).
3. Without `outputs.detach()`, `dis_loss.backward()` would also propagate gradients back into the generator. Since the two optimizers update disjoint parameter sets (and gradients are zeroed before each update), this would not actually corrupt training. But `outputs.detach()` speeds things up: the discriminator's backward pass no longer needs to traverse the generator at all.
4. In `gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)`, the generator branch deliberately has no `detach()`: detaching would block gradient flow back to the generator, so it could not be trained.
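To make points 1–4 concrete, here is a minimal, self-contained sketch (the tiny `G`/`D` networks and shapes are made up for illustration, not taken from the codebase above) showing how `requires_grad` propagates and what `detach()` changes:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the generator and discriminator
G = nn.Linear(4, 4)
D = nn.Linear(4, 1)

x = torch.randn(2, 4)             # raw input: requires_grad=False by default (point 1)
fake = G(x)                       # output of a network whose weights require grad
print(x.requires_grad)            # False
print(fake.requires_grad)         # True (point 2)

# Discriminator loss on detached fakes: backward stops at the detach point (point 3)
d_loss = D(fake.detach()).mean()
d_loss.backward()
print(G.weight.grad is None)      # True  -> no gradient reached the generator
print(D.weight.grad is not None)  # True  -> the discriminator got its gradient

# Generator loss without detach: gradients flow through D back into G (point 4)
g_loss = D(fake).mean()
g_loss.backward()
print(G.weight.grad is not None)  # True  -> the generator can now be updated
```

This mirrors the snippet at the top: detaching in the discriminator branch keeps its backward pass cheap and leaves the generator untouched, while the generator branch keeps the graph intact on purpose.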
```python
def backward(self, gen_loss=None, dis_loss=None):
    if dis_loss is not None:
        dis_loss.backward()
        self.dis_optimizer.step()

    if gen_loss is not None:
        gen_loss.backward()
        self.gen_optimizer.step()
```
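The two optimizers are built over disjoint parameter sets, which is why the separate `backward()`/`step()` pairs above cannot interfere with each other's weights: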
```python
self.gen_optimizer = optim.Adam(
    params=generator.parameters(),
    lr=float(config.LR),
    betas=(config.BETA1, config.BETA2)
)

self.dis_optimizer = optim.Adam(
    params=discriminator.parameters(),
    lr=float(config.LR) * float(config.D2G_LR),
    betas=(config.BETA1, config.BETA2)
)
```
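For completeness, here is a rough sketch of how one training iteration could tie these pieces together; the loop, the `model.process(...)` wrapper, and the `zero_grad()` placement are assumptions for illustration, not shown in the original snippet:

```python
# Hypothetical training step (model.process and train_loader are assumed names)
for images, edges, masks in train_loader:
    # Zero both parameter sets so stale gradients never leak into an update.
    model.gen_optimizer.zero_grad()
    model.dis_optimizer.zero_grad()

    # Forward pass + loss computation (the snippet at the top of this post).
    outputs, gen_loss, dis_loss = model.process(images, edges, masks)

    # Separate backward/step per network, as in backward() above.
    model.backward(gen_loss=gen_loss, dis_loss=dis_loss)
```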