Understanding PatchGAN

Discriminator design

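The discriminator in CycleGAN uses a design called "PatchGAN". The discriminator of the original GAN outputs a single score (real or fake), which judges the entire image produced by the generator. PatchGAN works differently: it is fully convolutional (which is why the author said at the end of the preceding text that PatchGAN could also be called a fully convolutional GAN). After the image passes through the convolutional layers, the features are not fed into a fully connected layer or a final activation. Instead, a last convolution maps them to an N×N matrix, which plays the role of the single score in the original GAN and is used to judge the generated image.

Each element of the N×N matrix is a real/fake score for one small region (a "patch") of the input image, namely that element's receptive field (see the figure below). So instead of one value judging the whole image, an N×N matrix judges it, and the PatchGAN labels must accordingly also be N×N tensors so that the loss can be computed element by element. Because every patch is scored separately, the discriminator has to look at local detail everywhere in the image rather than only at a global impression, and that is the main advantage of PatchGAN.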

[Figure: receptive field]
This part is quoted from an answer on Zhihu; the link is here.
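The size of the patch that each output element "sees" can be worked out from the kernel sizes and strides alone. The sketch below assumes the default CycleGAN configuration shown in the NLayerDiscriminator code that follows (five 4×4 convolutions with strides 2, 2, 2, 1, 1); the helper name `receptive_field` is hypothetical. It walks the layers from the output back to the input using r ← r·s + (k − s).

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one element of the final output map.

    layers: list of (kernel_size, stride) tuples, ordered from input to output.
    Padding shifts the window but does not change its size, so it is ignored.
    """
    r = 1  # start from a single element of the output map
    for k, s in reversed(layers):
        # stepping back through one conv layer widens the window
        r = r * s + (k - s)
    return r


# Hypothetical layer list for the default discriminator (n_layers=3):
# three stride-2 convs followed by two stride-1 convs, all with kernel 4
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_layers))
```

For these settings the result is 70, which is where the name "70×70 PatchGAN" comes from: each element of the output map scores a 70×70 patch of the input.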

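In other words, the discriminator is simply a stack of convolutional layers, and the last convolution collapses the final feature map to a single channel.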
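How large N is depends on the input resolution. A small sketch using the standard convolution output-size formula out = ⌊(in + 2p − k)/s⌋ + 1, assuming the same layer settings as the code below (kernel 4, padding 1, strides 2, 2, 2, 1, 1) and a square input; `conv_out` and `patch_map_size` are hypothetical helper names.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_map_size(image_size, strides=(2, 2, 2, 1, 1)):
    """Side length N of the N×N prediction map for a square input image."""
    size = image_size
    for s in strides:
        size = conv_out(size, stride=s)
    return size


for side in (256, 128):
    n = patch_map_size(side)
    print(f"{side}x{side} input -> {n}x{n} prediction map")
```

For 256×256 inputs (CycleGAN's default crop size) this gives a 30×30 map, so the label tensors used by the loss also have to be 30×30 per image.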

import functools

import torch.nn as nn


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
# Resulting network structure (input_nc=3, ndf=64, n_layers=3, InstanceNorm2d):
"""[Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)), LeakyReLU(negative_slope=0.2, inplace=True), 
Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)), InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False), LeakyReLU(negative_slope=0.2, inplace=True),
 Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)), InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False), LeakyReLU(negative_slope=0.2, inplace=True),
 Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1)), InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False), LeakyReLU(negative_slope=0.2, inplace=True),
 Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))]"""
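The channel bookkeeping in `__init__` (`nf_mult`, `nf_mult_prev`) is easier to follow when you just list the (in_channels, out_channels, stride) schedule it produces. This pure-Python sketch mirrors that loop without PyTorch; `layer_schedule` is a hypothetical name, and the normalization and LeakyReLU layers are left out.

```python
def layer_schedule(input_nc, ndf=64, n_layers=3):
    """(in_channels, out_channels, stride) for each Conv2d in NLayerDiscriminator."""
    convs = [(input_nc, ndf, 2)]  # first conv: stride 2, no norm layer
    nf_mult = 1
    for n in range(1, n_layers):  # stride-2 convs, width doubles but is capped at 8 * ndf
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
        convs.append((ndf * nf_mult_prev, ndf * nf_mult, 2))
    nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
    convs.append((ndf * nf_mult_prev, ndf * nf_mult, 1))  # stride-1 conv
    convs.append((ndf * nf_mult, 1, 1))  # 1-channel prediction map
    return convs


for n_layers in (3, 5):
    print(n_layers, layer_schedule(3, n_layers=n_layers))
```

With n_layers=3 this reproduces the printed structure above. With n_layers=5 the channel count stops growing at ndf * 8 = 512 because of `min(2 ** n, 8)`, while the extra stride-2 layers still enlarge each patch's receptive field.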

Loss function design

  • But how is this loss actually used? Most articles skip this, so I went through the source code: it is in fact an MSE loss, i.e. an L2 loss.

  • First, how the discriminator loss is composed:

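def backward_D_basic(self, netD, real, fake):
    """Discriminator loss for one domain (a method of the CycleGAN model class)."""
    # Real: the patch map for real images is compared with an all-"real" label map
    pred_real = netD(real)
    loss_D_real = self.criterionGAN(pred_real, True)
    # Fake: detach() so that no gradients flow back into the generator
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # Combined loss and calculate gradients
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    loss_D.backward()
    return loss_D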
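Written out for the 'lsgan' mode that CycleGAN uses, the loss computed by this function is roughly as follows, where D(·) is the N×N prediction map, y is a real image, G(x) is a generated one, and the average over i, j comes from the default 'mean' reduction of nn.MSELoss (the batch dimension is omitted for brevity):

```latex
\mathcal{L}_D = \frac{1}{2}\left[ \frac{1}{N^2}\sum_{i,j}\bigl(D(y)_{ij} - 1\bigr)^2
              + \frac{1}{N^2}\sum_{i,j} D\bigl(G(x)\bigr)_{ij}^{2} \right]
```

The factor 1/2 is the `* 0.5` in the code; the CycleGAN paper describes halving the discriminator's objective as a way to slow down D's learning relative to G.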

As you can see, criterionGAN (an instance of the GANLoss class below) is what computes the PatchGAN loss:

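import torch
import torch.nn as nn


class GANLoss(nn.Module):
    # Note: the label tensor must be created with the same size as the input (the N×N prediction map)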

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        CycleGAN uses gan_mode='lsgan'.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.
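        If target_is_real, the label is a tensor of 1s with the same size as the prediction; otherwise a tensor of 0s. expand_as broadcasts the scalar label buffer to that shape.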
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):

        # Acts like forward()
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)

            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss
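To see concretely what GANLoss does with a patch map, here is a dependency-free sketch for a single 2×2 prediction map. It imitates `expand_as` by building a label grid of the same shape and then averages the element-wise loss, as `nn.MSELoss` ('lsgan') and `nn.BCEWithLogitsLoss` ('vanilla') do with their default mean reduction. The function name `patch_gan_loss` and the score values are made up for illustration.

```python
import math


def patch_gan_loss(prediction, target_is_real, gan_mode="lsgan"):
    """Average loss over an N×N patch map given as a list of lists."""
    label = 1.0 if target_is_real else 0.0
    # Equivalent of target_tensor.expand_as(prediction): one label per patch
    target = [[label] * len(row) for row in prediction]
    losses = []
    for pred_row, tgt_row in zip(prediction, target):
        for p, t in zip(pred_row, tgt_row):
            if gan_mode == "lsgan":
                losses.append((p - t) ** 2)  # squared error, averaged below (MSE)
            elif gan_mode == "vanilla":
                # numerically stable binary cross-entropy on a raw logit p
                losses.append(max(p, 0.0) - p * t + math.log1p(math.exp(-abs(p))))
            else:
                raise NotImplementedError(gan_mode)
    return sum(losses) / len(losses)


pred_real = [[0.9, 0.2], [0.6, 0.4]]  # D's patch scores for a real image
pred_fake = [[0.3, 0.1], [0.5, 0.2]]  # D's patch scores for a generated image
loss_d = 0.5 * (patch_gan_loss(pred_real, True) + patch_gan_loss(pred_fake, False))
print(round(loss_d, 4))
```

Because the label is just broadcast to the prediction's shape, the same loss works for any N without building label maps by hand.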
Author: live_for_myself
