1. 概述
在真实数据超分任务上,从SRGAN开始,Loss函数基本是Pixel loss + GAN loss + Perceptual loss的组合。
与生成任务不同,对于超分这种复原任务,如果只使用Gan loss或者GAN loss的权重比较大的话,效果就比较差。
SRGAN成功的两个关键点:1. 引入了感知损失函数(Perceptual Loss),它是让生成图像产生细节的关键,而不是对抗损失函数。2. 将对抗损失函数的权重调小,让它不能影响训练的方向,只会微调生成图像的清晰度,消除感知损失函数带来的噪声。参见底层视觉之美。
在实践中,一般gan loss的权重设置为Pixel loss的千分之一。
2. 超分中的判别器
判别器一般来说有三种:
- 分类网络 vgg,resnet等
最后一层输出输出一个数字,代表整张图的判别结果 - Patch gan
最后一层不再输出一个数字,而是输出1xnxn的特征图,其中的每一个数字代表了原图中一个patch的判别结果;最后的loss通过对这nxn个点求均值得到; - U-Net discriminator with spectral normalization (SN).
在Real ESRGAN中提出的,因为unet的输入分辨率和输出分辨率一致,相当于unet判别器对每个像素进行了判别,最后的loss求均值得到;引入spectral normalization 是为了稳定训练,同时可以消除一些artifacts;
3. 超分中的几种 Gan loss
3.1 Vanilla GAN
最原始的gan loss,判别器做的是二分类任务,判别器的最后输出经过sigmoid后计算交叉熵;一般用
self.loss = nn.BCEWithLogitsLoss()
实现,其相当于sigmoid + 交叉熵;
3.2 LSGAN (最小平方gan)
不去算sigmoid和交叉熵,而是直接算判别器预测输出与真实标签值的MSE;一般用self.loss = nn.MSELoss()
3.3 WGAN loss
WGAN是对原始的GAN的改进,优化了其会发生梯度消失训练不稳定的问题,原始的GAN最小化生成器loss等价于最小化真实分布P_r与生成分布P_g之间的JS散度 → WGAN最小化真实分布P_r与生成分布P_g之间的Wasserstein距离;
具体来说,WGAN去掉了sigmoid, 同时也不再计算交叉熵,而是直接返回D(x)的均值。因为一般来说,都是最小化loss,对于真实样本直接输出-input.mean()
;对于生成样本,如果是优化生成器的时候,wgan loss为-input.mean()
,如果是优化判别器,则输出input.mean()
;代码如下;
def wgan_loss(input, target):
# target is boolean
return -1 * input.mean() if target else input.mean()
3.4 RAGAN (相对Gan)
衡量的是真实数据比生成数据真实的概率,也就是说原始的GAN是将判别器的输出直接计算loss,而RAGAN会先计算真实样本的判别器输出和生成样本的判别器输出,做差值后再进行loss计算;比如生成器loss如下:
D_loss = self.D_lossfn_weight * (
self.D_lossfn(pred_d_real - torch.mean(pred_g_fake, 0, True), False) +
self.D_lossfn(pred_g_fake - torch.mean(pred_d_real, 0, True), True)) / 2
3.5 代码(来自KAIR)
- Loss函数定义代码
class GANLoss(nn.Module): def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): super(GANLoss, self).__init__() self.gan_type = gan_type.lower() self.real_label_val = real_label_val self.fake_label_val = fake_label_val # 原始gan和ragan都是二分类 if self.gan_type == 'gan' or self.gan_type == 'ragan': self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == 'lsgan': self.loss = nn.MSELoss() elif self.gan_type == 'wgan': def wgan_loss(input, target): # target is boolean return -1 * input.mean() if target else input.mean() self.loss = wgan_loss elif self.gan_type == 'softplusgan': def softplusgan_loss(input, target): # target is boolean return F.softplus(-input).mean() if target else F.softplus(input).mean() self.loss = softplusgan_loss else: raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) def get_target_label(self, input, target_is_real): if self.gan_type in ['wgan', 'softplusgan']: return target_is_real # 返回标签,如果target_is_real为true,则返回全1的标签;如果为false则返回全0的标签 if target_is_real: return torch.empty_like(input).fill_(self.real_label_val) else: return torch.empty_like(input).fill_(self.fake_label_val) def forward(self, input, target_is_real): target_label = self.get_target_label(input, target_is_real) loss = self.loss(input, target_label) return loss ```
- 判别器Loss计算代码
if self.opt_train['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']: # real pred_d_real = self.netD(self.H) # 1) real data l_d_real = self.D_lossfn(pred_d_real, True) l_d_real.backward() # fake pred_d_fake = self.netD(self.E.detach().clone()) # 2) fake data, detach to avoid BP to G l_d_fake = self.D_lossfn(pred_d_fake, False) l_d_fake.backward() elif self.opt_train['gan_type'] == 'ragan': # real pred_d_fake = self.netD(self.E).detach() # 1) fake data, detach to avoid BP to G pred_d_real = self.netD(self.H) # 2) real data l_d_real = 0.5 * self.D_lossfn(pred_d_real - torch.mean(pred_d_fake, 0, True), True) l_d_real.backward() # fake pred_d_fake = self.netD(self.E.detach()) l_d_fake = 0.5 * self.D_lossfn(pred_d_fake - torch.mean(pred_d_real.detach(), 0, True), False) l_d_fake.backward() ```
- 生成器loss计算代码
if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']: pred_g_fake = self.netD(self.E) D_loss = self.D_lossfn_weight * self.D_lossfn(pred_g_fake, True) elif self.opt['train']['gan_type'] == 'ragan': pred_d_real = self.netD(self.H).detach() pred_g_fake = self.netD(self.E) # 相对判别器 D_loss = self.D_lossfn_weight * ( self.D_lossfn(pred_d_real - torch.mean(pred_g_fake, 0, True), False) + self.D_lossfn(pred_g_fake - torch.mean(pred_d_real, 0, True), True)) / 2