说明
由于在实际使用中遇到了多种形式的GANLoss,就整理了以下常用的四种GANLoss在应用中的区别,包括’vanilla’, ‘lsgan’, ‘wgan’, ‘hinge’。
vanilla
2014年由Ian Goodfellow
最普通,最基础的一种形式。采用nn.BCEWithLogitsLoss(),即sigmoid + BCELoss,
self.loss = nn.BCEWithLogitsLoss()
具体计算式:
Ld = -[ylogD(x)+(1-y)log(1-D(G(Z)))]
Lg = -[ylogD(x)+(1-y)log(1-D(G(Z)))]
具体代码测试是如下
from gan_loss_comps import GANLossComps
import torch.nn as nn
import numpy as np
import torch
import numpy.testing as npt
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2
gan_loss = GANLossComps(
'vanilla', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
#loss = -[ylogy+(1-y)log(1-y)]
#Ld = -[ylogD(x)+(1-y)log(1-D(G(Z)))]
#Lg = -[ylogD(x)+(1-y)log(1-D(G(Z)))]
#G
loss = gan_loss(input_1, True, is_disc=False) #-1*(np.log(1/(1+np.exp(-1))))
npt.assert_almost_equal(loss.item(), 0.6265233)
#D
loss = gan_loss(input_1, True, is_disc=True)
npt.assert_almost_equal(loss.item(), 0.3132616)
loss = gan_loss(input_1, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 1.3132616)
lsgan
常规GAN默认的判别器设置是sigmoid交叉熵损失函数训练的分类器。但是,在训练学习过程中这种损失函数的使用可能会导致梯度消失。为了克服这个问题,最小二乘生成对抗网络(LSGAN)采用最小二乘的损失来缓解。实际上,LSGAN的目标函数将本质上是最小化Pearson χ2散度。
与常规GAN相比,LSGAN有两个好处:一是能生成更高质量的图像;二是在训练过程中更稳定。
https://blog.csdn.net/lgzlgz3102/article/details/115475370
self.loss = nn.MSELoss()
## lsgan
input_1 = torch.ones(1, 1)
input_2 = torch.ones(1, 3, 6, 6) * 2
## lsgan
gan_loss = GANLossComps(
'lsgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
#Ld = y*(y-D(x))^2 + (1-y)*(y-D(G(x)))^2
#Lg = (y-D(G(x)))^2
#gen
loss = gan_loss(input_2, True, is_disc=False) #loss = 2*(y-D(x))^2 = 2* 1^2 = 2
npt.assert_almost_equal(loss.item(), 2.0)
#dis
loss = gan_loss(input_2, True, is_disc=True) #loss = (y-D(x))^2 = 1^2 = 1
npt.assert_almost_equal(loss.item(), 1.0)
loss = gan_loss(input_2, False, is_disc=True) #loss = (y-D(x))^2 = 2^2 = 4
npt.assert_almost_equal(loss.item(), 4.0)
wgan
Wasserstein GAN
解决的问题:
模式崩溃,生成器生成非常窄的分布,仅覆盖数据分 布中的单一模式。 模式崩溃的含义是生成器只能生成非常相似的样本(例如 ,MNIST中的单个数字),即生成的样本不是多样的。
没有指标可以告诉我们收敛情况。生成器和判别器的 loss并没有告诉我们任何收敛相关信息。当然,我们可以通 过不时地查看生成器生成的数据来监控训练进度。但是, 这是一个手动过程。因此,我们需要有一个可解释的指标 可以告诉我们有关训练的进度。
https://blog.csdn.net/m0_62128864/article/details/124258797
https://zhuanlan.zhihu.com/p/361808267
self.loss = -input.mean() if target else input.mean()
具体计算式:
loss = -[ylogD(x)+(1-y)log(1-D(G(Z)))] 修改为 -[yD(x)-(1-y)(D(G(Z)))] #去掉log,(1-D(G(Z)) 换成-D(G(Z)
# wgan
gan_loss = GANLossComps(
'wgan', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
#Ld = -[yD(x)-(1-y)(D(G(Z)))]
#Lg = - yD(x)
loss = gan_loss(input_2, True, is_disc=False) #-2*(2)=-4
npt.assert_almost_equal(loss.item(), -4.0)
loss = gan_loss(input_2, True, is_disc=True)
npt.assert_almost_equal(loss.item(), -2.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 2.0)
hinge
对于D来说,只有当D(x) < 1 的正向样本,以及D(G(z)) > -1的负样本才会对结果产生影响
也就是说,只有一些没有被合理区分的样本,才会对梯度产生影响
https://zh.wikipedia.org/zh-cn/Hinge_loss
https://zhuanlan.zhihu.com/p/72195907
self.loss = nn.ReLU()
具体计算式:
Ld = E(max(0,1-D(x)))+E(max(0,1+D(G(z))))
Lg = -E(D(G(z)))
# hinge
gan_loss = GANLossComps(
'hinge', loss_weight=2.0, real_label_val=1.0, fake_label_val=0.0)
# G
loss = gan_loss(input_2, True, is_disc=False) #跟wgan一样直接输出-input.mean()
npt.assert_almost_equal(loss.item(), -4.0)
# D
loss = gan_loss(input_2, True, is_disc=True)
npt.assert_almost_equal(loss.item(), 0.0)
loss = gan_loss(input_2, False, is_disc=True)
npt.assert_almost_equal(loss.item(), 3.0)
总结
vanilla :sigmoid + BCELoss
lsgan : MSE
wgan : 去掉log,(1-D(G(Z)) 换成-D(G(Z),限制L
hinge: 限制E(max(0,1-D(x)))+E(max(0,1+D(G(z))))
附录
文中使用到的GANLossComps
类,作为附录传在下方
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class GANLossComps(nn.Module):
"""Define GAN loss.
Args:
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge',
'wgan-logistic-ns'.
real_label_val (float): The value for real label. Default: 1.0.
fake_label_val (float): The value for fake label. Default: 0.0.
loss_weight (float): Loss weight. Default: 1.0.
Note that loss_weight is only for generators; and it is always 1.0
for discriminators.
"""
def __init__(self,
gan_type: str,
real_label_val: float = 1.0,
fake_label_val: float = 0.0,
loss_weight: float = 1.0) -> None:
super().__init__()
self.gan_type = gan_type
self.loss_weight = loss_weight
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan':
self.loss = self._wgan_loss
elif self.gan_type == 'wgan-logistic-ns':
self.loss = self._wgan_logistic_ns_loss
elif self.gan_type == 'hinge':
self.loss = nn.ReLU()
else:
raise NotImplementedError(
f'GAN type {self.gan_type} is not implemented.')
def _wgan_loss(self, input: torch.Tensor, target: bool) -> torch.Tensor:
"""wgan loss.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return -input.mean() if target else input.mean()
def _wgan_logistic_ns_loss(self, input: torch.Tensor,
target: bool) -> torch.Tensor:
"""WGAN loss in logistically non-saturating mode.
This loss is widely used in StyleGANv2.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return F.softplus(-input).mean() if target else F.softplus(
input).mean()
def get_target_label(self, input: torch.Tensor,
target_is_real: bool) -> Union[bool, torch.Tensor]:
"""Get target label.
Args:
input (Tensor): Input tensor.
target_is_real (bool): Whether the target is real or fake.
Returns:
(bool | Tensor): Target tensor. Return bool for wgan, otherwise, \
return Tensor.
"""
if self.gan_type in ['wgan', 'wgan-logistic-ns']:
return target_is_real
target_val = (
self.real_label_val if target_is_real else self.fake_label_val)
return input.new_ones(input.size()) * target_val
def forward(self,
input: torch.Tensor,
target_is_real: bool,
is_disc: bool = False) -> torch.Tensor:
"""
Args:
input (Tensor): The input for the loss module, i.e., the network
prediction.
target_is_real (bool): Whether the targe is real or fake.
is_disc (bool): Whether the loss for discriminators or not.
Default: False.
Returns:
Tensor: GAN loss value.
"""
target_label = self.get_target_label(input, target_is_real)
if self.gan_type == 'hinge':
if is_disc: # for discriminators in hinge-gan
input = -input if target_is_real else input
loss = self.loss(1 + input).mean()
else: # for generators in hinge-gan
loss = -input.mean()
else: # other gan types
loss = self.loss(input, target_label)
# loss_weight is always 1.0 for discriminators
return loss if is_disc else loss * self.loss_weight