[1609.04802] SRGAN中的那些loss

1. loss

深度神经网络模型的设计中,loss绝对要占据一席之位。不同的loss形式,对优化的结果,差别很大。啥也不说,先上代码。

2. 代码

for epoch in range(args.num_epochs):
    for i, data in enumerate(dataloader_train):
        # forward
        x, target = data
        x = Variable(x)
        y_real = Variable(target)
        target_real = Variable(torch.rand(args.batch_size, 1)*0.5 + 0.7)     
        target_fake = Variable(torch.rand(args.batch_size, 1)*0.3)
        y_fake = generator(x)

        # train discriminator
        discriminator.zero_grad()
        discriminator_loss = adversarial_criterion(discriminator(y_real), target_real) + \
                             adversarial_criterion(discriminator(y_fake), target_fake)   
        mean_discriminator_loss += discriminator_loss.data[0]

        discriminator_loss.backward() 
        optim_discriminator.step()

        # train generator 
        generator.zero_grad()

        features_real = Variable(feature_extractor(y_real).data)
        features_fake = Variable(feature_extractor(y_fake))
        generator_content_loss = content_criterion(y_fake, y_real) + \
                                 content_criterion(features_fake, features_real) * 0.006
        mean_generator_content_loss += generator_content_loss.data[0]

        generator_adversarial_loss = adversarial_criterion(discriminator(y_fake), ones_const)
        mean_generator_adversarial_loss += generator_adversarial_loss.data[0]

        generator_total_loss = generator_content_loss + 1e-3*generator_adversarial_loss
        mean_generator_total_loss += generator_total_loss.data[0]

        generator_total_loss.backward()
        optim_generator.step()

典型的pytorch训练循环的代码。代码中出现了discriminator_lossgenerator_content_lossgenerator_adversarial_loss三种loss,第一个用来训练判别器,后两个加起来,训练生成器。

上面计算loss的代码,可视化框图如下
这里写图片描述

训练流程:
- 沿着红色虚线,计算判别损失,更新判别器参数 Dθ
- 沿着粉色虚线,计算产生损失,更新产生器参数 Gθ

3. 和不用GAN的区别

如果不用GAN,模型仅仅用一个G网络,产生y_fake,和y_real求得MSEloss,用这个损失更新网络参数。

而GAN的作用,是额外增加一个D网络和2个损失(判别损失和生成判别损失),用一种交替训练的方式训练两个网络。这个模型可以分为3部分:main模块,adversarial模块,和vgg模块。(一般main模块就是adversarial模块里的G网络)adversarial可以看作是一种训练技巧,只在训练阶段会用到adversarial模块进行计算,而在推断阶段,仅仅使用G网络(或者说main模块)。

也就是任何一个问题,都可以让训练过程“对抗化”。“对抗化”的步骤是:
- 确定main模块(原始问题的解决办法)
- 把main模块当成GAN中的G网络
- 另外增加一个D网络(二分类网络)
- 在原来更新main模块的loss中,增加“生成对抗损失”(要生成让判别器无法区分的数据分布),一起用来更新main模块(也就是GAN中的G网络)
- 用判别损失更新GAN中的D网络

  • 6
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值