【GAN】GANLoss之‘vanilla‘, ‘lsgan‘, ‘wgan‘, ‘hinge‘的具体计算方式及实现

说明

由于在实际使用中遇到了多种形式的GANLoss,就整理了以下常用的四种GANLoss在应用中的区别,包括’vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’。

vanilla

2014年由Ian Goodfellow
最普通,最基础的一种形式。采用nn.BCEWithLogitsLoss(),即sigmoid + BCELoss,

self.loss = nn.BCEWithLogitsLoss()

具体计算式:

Ld = -[ylogD(x)+(1-y)log(1-D(G(Z)))]
Lg = -[ylogD(x)+(1-y)log(1-D(G(Z)))]

具体代码测试是如下

from  gan_loss_comps import GANLossComps
import torch.nn as nn
import numpy as np
import torch
import numpy.testing as npt
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2
gan_loss = GANLossComps(
    'vanilla', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0) 
#loss = -[ylogy+(1-y)log(1-y)]
#Ld = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 
#Lg = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 
#G
loss = gan_loss(input_1, True, is_disc=False) #-1*(np.log(1/(1+np.exp(-1))))
npt.assert_almost_equal(loss.item(), 0.6265233)
#D
loss = gan_loss(input_1, True, is_disc=True)
npt.assert_almost_equal(loss.item(), 0.3132616)
loss = gan_loss(input_1, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 1.3132616)

lsgan

常规GAN默认的判别器设置是sigmoid交叉熵损失函数训练的分类器。但是,在训练学习过程中这种损失函数的使用可能会导致梯度消失。为了克服这个问题,最小二乘生成对抗网络(LSGAN)采用最小二乘的损失来缓解。实际上,LSGAN的目标函数将本质上是最小化Pearson χ2散度。
与常规GAN相比,LSGAN有两个好处:一是能生成更高质量的图像;二是在训练过程中更稳定。

https://blog.csdn.net/lgzlgz3102/article/details/115475370

self.loss = nn.MSELoss()
## lsgan
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2

## lsgan
gan_loss = GANLossComps(
         'lsgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
         
#Ld = y*(y-D(x))^2 + (1-y)*(y-D(G(x)))^2
#Lg = (y-D(G(x)))^2

#gen
loss = gan_loss(input_2, True, is_disc=False) #loss = 2*(y-D(x))^2 = 2* 1^2 = 2 
npt.assert_almost_equal(loss.item(), 2.0)
#dis
loss = gan_loss(input_2, True, is_disc=True) #loss = (y-D(x))^2 =  1^2 = 1 
npt.assert_almost_equal(loss.item(), 1.0)
loss = gan_loss(input_2, False, is_disc=True) #loss = (y-D(x))^2 =  2^2 = 4 
npt.assert_almost_equal(loss.item(), 4.0)

wgan

Wasserstein GAN
解决的问题:
模式崩溃,生成器生成非常窄的分布,仅覆盖数据分 布中的单一模式。 模式崩溃的含义是生成器只能生成非常相似的样本(例如 ,MNIST中的单个数字),即生成的样本不是多样的。
没有指标可以告诉我们收敛情况。生成器和判别器的 loss并没有告诉我们任何收敛相关信息。当然,我们可以通 过不时地查看生成器生成的数据来监控训练进度。但是, 这是一个手动过程。因此,我们需要有一个可解释的指标 可以告诉我们有关训练的进度。

https://blog.csdn.net/m0_62128864/article/details/124258797
https://zhuanlan.zhihu.com/p/361808267

 self.loss = -input.mean() if target else input.mean()

具体计算式:

loss = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 修改为 -[yD(x)-(1-y)(D(G(Z)))] #去掉log,(1-D(G(Z)) 换成-D(G(Z)

# wgan
gan_loss = GANLossComps(
    'wgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)

#Ld = -[yD(x)-(1-y)(D(G(Z)))]
#Lg = - yD(x)

loss = gan_loss(input_2, True, is_disc=False) #-2*(2)=-4
npt.assert_almost_equal(loss.item(), -4.0)

loss = gan_loss(input_2, True, is_disc=True)
npt.assert_almost_equal(loss.item(), -2.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 2.0)

hinge

对于D来说,只有当D(x) < 1 的正向样本,以及D(G(z)) > -1的负样本才会对结果产生影响
也就是说,只有一些没有被合理区分的样本,才会对梯度产生影响

https://zh.wikipedia.org/zh-cn/Hinge_loss
https://zhuanlan.zhihu.com/p/72195907

self.loss = nn.ReLU()

具体计算式:

Ld = E(max(0,1-D(x)))+E(max(0,1+D(G(z))))
Lg = -E(D(G(z)))

# hinge
gan_loss = GANLossComps(
    'hinge', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# G
loss = gan_loss(input_2, True, is_disc=False) #跟wgan一样直接输出-input.mean()
npt.assert_almost_equal(loss.item(), -4.0)

# D
loss = gan_loss(input_2, True, is_disc=True) 
npt.assert_almost_equal(loss.item(), 0.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 3.0)

总结

vanilla :sigmoid + BCELoss
lsgan : MSE
wgan : 去掉log,(1-D(G(Z)) 换成-D(G(Z),限制L
hinge: 限制E(max(0,1-D(x)))+E(max(0,1+D(G(z))))

附录

文中使用到的GANLossComps类,作为附录传在下方

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

class GANLossComps(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
            'wgan-logistic-ns'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self,
                 gan_type: str,
                 real_label_val: float = 1.0,
                 fake_label_val: float = 0.0,
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan-logistic-ns':
            self.loss = self._wgan_logistic_ns_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor:
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_logistic_ns_loss(self, input: torch.Tensor,
                               target: bool) -> torch.Tensor:
        """WGAN loss in logistically non-saturating mode.

        This loss is widely used in StyleGANv2.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """

        return F.softplus(-input).mean() if target else F.softplus(
            input).mean()

    def get_target_label(self, input: torch.Tensor,
                         target_is_real: bool) -> Union[bool, torch.Tensor]:
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise, \
                return Tensor.
        """

        if self.gan_type in ['wgan', 'wgan-logistic-ns']:
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self,
                input: torch.Tensor,
                target_is_real: bool,
                is_disc: bool = False) -> torch.Tensor:
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the targe is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight

  • 2
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

RichardCV

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值