1. loss
深度神经网络模型的设计中,loss绝对要占据一席之位。不同的loss形式,对优化的结果,差别很大。啥也不说,先上代码。
2. 代码
for epoch in range(args.num_epochs):
for i, data in enumerate(dataloader_train):
# forward
x, target = data
x = Variable(x)
y_real = Variable(target)
target_real = Variable(torch.rand(args.batch_size, 1)*0.5 + 0.7)
target_fake = Variable(torch.rand(args.batch_size, 1)*0.3)
y_fake = generator(x)
# train discriminator
discriminator.zero_grad()
discriminator_loss = adversarial_criterion(discriminator(y_real), target_real) + \
adversarial_criterion(discriminator(y_fake), target_fake)
mean_discriminator_loss += discriminator_loss.data[0]
discriminator_loss.backward()
optim_discriminator.step()
# train generator
generator.zero_grad()
features_real = Variable(feature_extractor(y_real).data)
features_fake = Variable(feature_extractor(y_fake))
generator_content_loss = content_criterion(y_fake, y_real) + \
content_criterion(features_fake, features_real) * 0.006
mean_generator_content_loss += generator_content_loss.data[0]
generator_adversarial_loss = adversarial_criterion(discriminator(y_fake), ones_const)
mean_generator_adversarial_loss += generator_adversarial_loss.data[0]
generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
mean_generator_total_loss += generator_total_loss.data[0]
generator_total_loss.backward()
optim_generator.step()
典型的pytorch训练循环的代码。代码中出现了discriminator_loss、generator_content_loss、generator_adversarial_loss三种loss,第一个用来训练判别器,后两个加起来,训练生成器。
上面计算loss的代码,可视化框图如下
训练流程:
- 沿着红色虚线,计算判别损失,更新判别器参数
Dθ
- 沿着粉色虚线,计算产生损失,更新产生器参数
Gθ
3. 和不用GAN的区别
如果不用GAN,模型仅仅用一个G网络,产生y_fake,和y_real求得MSEloss,用这个损失更新网络参数。
而GAN的作用,是额外增加一个D网络和2个损失(判别损失和生成判别损失),用一种交替训练的方式训练两个网络。这个模型可以分为3部分:main模块,adversarial模块,和vgg模块。(一般main模块就是adversarial模块里的G网络)adversarial可以看作是一种训练技巧,只在训练阶段会用到adversarial模块进行计算,而在推断阶段,仅仅使用G网络(或者说main模块)。
也就是任何一个问题,都可以让训练过程“对抗化”。“对抗化”的步骤是:
- 确定main模块(原始问题的解决办法)
- 把main模块当成GAN中的G网络
- 另外增加一个D网络(二分类网络)
- 在原来更新main模块的loss中,增加“生成对抗损失”(要生成让判别器无法区分的数据分布),一起用来更新main模块(也就是GAN中的G网络)
- 用判别损失更新GAN中的D网络