把 GAN 运用在其他视觉任务上 | 图像超分经典网络 SRGAN 解析
GAN 不仅能生成图片,还能运用在其他视觉任务上
图像超分 SRGAN 解析,教你把 GAN 运用在其他视觉任务上
导读:
本文来自社区投稿,作者周弈帆。
生成对抗网络(GAN)是一类非常有趣的神经网络。借助 GAN,计算机能够生成逼真的图片。近年来有许多“ AI 绘画”的新闻,这些应用大多是通过 GAN 实现的。实际上,GAN 不仅能做图像生成,还能辅助其他输入信息不足的视觉任务。比如 SRGAN,就是把 GAN 应用在超分辨率(SR)任务上的代表之作。
在这篇文章中,作者将主要面向深度学习的初学者,介绍 SRGAN[1] 这篇论文,同时分享以下知识:
- GAN 的原理与训练过程
- 感知误差(Perceptual Loss)
- 基于的 GAN 的 SR 模型框架
目前 OpenMMLab 的 MMEditing 算法库已经支持了 SRGAN,讲完了上述知识后,作者还会解读一下 MMEditing 的 SRGAN 的训练代码。看懂这份代码能够加深对 SRGAN 训练算法的理解。
下面就让我们进入今天的正题吧~
SRGAN 核心思想
早期超分辨率方法的优化目标都是降低低清图像和高清图像之间的均方误差。降低均方误差,确实让增强图像和原高清图像的相似度更高。但是,图像的相似度指标高并不能代表图像的增强质量就很高。下图显示了插值、优化均方误差、SRGAN、原图这四个图像输出结果(括号里的相似度指标是 PSNR 和 SSIM)。
从图中可以看出,优化均方误差虽然能让相似度指标升高,但图像的细节十分模糊,尤其是纹理比较密集的高频区域。相比之下,SRGAN 增强出来的图像虽然相似度不高,但看起来更加清晰。
为什么 SRGAN 的增强结果那么清楚呢?这是因为 SRGAN 使用了一套新的优化目标。SRGAN 使用的损失函数既包括了 GAN 误差,也包括了感知误差。这套新的优化目标能够让网络生成看起来更清楚的图片,而不仅仅是和原高清图像相似度更高的图片。
下面,我们来一步一步学习 SRGAN 的框架。
GAN 的原理
基于 GAN 的超分辨率网络
基于感知的内容误差
在介绍 SRGAN 的内容误差之前,需要对“内容误差”和“感知误差”这两个名词做一个澄清。在 SRGAN的原文章中,作者把内容误差和对抗误差之和叫做感知误差。但是,后续的大部分文献只把这种内容误差叫做感知误差,不会把内容误差和对抗误差放在一起称呼。在后文中,我也会用“感知误差”来指代 SRGAN 中的“内容误差”。
在深度卷积神经网络(CNN)火起来后,人们开始研究为什么 CNN 能够和人类一样识别出图像。经实验,人们发现两幅图像经 VGG(一个经典的 CNN)的某些中间层的输出越相似,两幅图像从观感上也越相似。这种相似度并不是基于某种数学指标,而是和人的感知非常类似。
VGG 的这种“感知性”被运用在了风格迁移等任务上。也有人考虑把这种感知上的误差运用到超分辨率任务上,并取得了不错的结果[3]。下图是真值、插值、基于逐像素误差、基于感知误差的四个超分辨率结果。
SRGAN 的其他模块
定义好了误差函数,只要在决定好网络结构就可以开始训练网络了。SRGAN 使用的生成网络和判别网络的结构如下:
判别网络就是一个平平无奇的二分类网络,架构上没有什么创新。而生成网络则先用几个残差块提取特征,最后用一种超分辨率任务中常用的上采样模块 PixelShuffle 对原图像的尺寸翻倍两次,最后输出一个边长放大 4 倍的高清图像。
SRGAN 的这种网络结构在当时确实取得了不错的结果。但是,很快就有后续研究提出了更好的网络架构。比如 ESRGAN[4] 去掉了生成网络的 BN 层,提出了一种叫做 RRDB 的高级模块。基于 RRDB的生成网络有着更好的生成效果。
不仅是网络架构,SRGAN 的其他细节也得到了后续研究的改进。GAN 误差的公式、总误差的公式、高清图像退化成低清图像的数据增强算法……这些子模块都被后续研究改进了。但是,SRGAN 这种基于 GAN 的训练架构一直没有发生改变。有了 SRGAN 的代码,想复现一些更新的超分辨率网络时,往往只需要换一下生成器的结构,或者改一改误差的公式就行了。大部分的训练代码是不用改变的。
总结
附录:MMEditing 中的 SRGAN
MMEditing 中的 SRGAN 写在 mmedit/models/restorers/srgan.py
这个文件里。学习训练逻辑时,我们只需要关注 SRGAN
类的 train_step
方法即可。
以下是 train_step
的源代码(我的 MMEditing **** 版本是 v0.15.1)。
MMEditing 中的 SRGAN 写在 mmedit/models/restorers/srgan.py
这个文件里。学习训练逻辑时,我们只需要关注 SRGAN
类的 train_step
方法即可。
以下是 train_step
的源代码(我的 MMEditing **** 版本是 v0.15.1)。
def train_step(self, data_batch, optimizer):
"""Train step.
Args:
data_batch (dict): A batch of data.
optimizer (obj): Optimizer.
Returns:
dict: Returned output.
"""
# data
lq = data_batch['lq']
gt = data_batch['gt']
# generator
fake_g_output = self.generator(lq)
losses = dict()
log_vars = dict()
# no updates to discriminator parameters.
set_requires_grad(self.discriminator, False)
if (self.step_counter % self.disc_steps == 0
and self.step_counter >= self.disc_init_steps):
if self.pixel_loss:
losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
if self.perceptual_loss:
loss_percep, loss_style = self.perceptual_loss(
fake_g_output, gt)
if loss_percep is not None:
losses['loss_perceptual'] = loss_percep
if loss_style is not None:
losses['loss_style'] = loss_style
# gan loss for generator
fake_g_pred = self.discriminator(fake_g_output)
losses['loss_gan'] = self.gan_loss(
fake_g_pred, target_is_real=True, is_disc=False)
# parse loss
loss_g, log_vars_g = self.parse_losses(losses)
log_vars.update(log_vars_g)
# optimize
optimizer['generator'].zero_grad()
loss_g.backward()
optimizer['generator'].step()
# discriminator
set_requires_grad(self.discriminator, True)
# real
real_d_pred = self.discriminator(gt)
loss_d_real = self.gan_loss(
real_d_pred, target_is_real=True, is_disc=True)
loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
optimizer['discriminator'].zero_grad()
loss_d.backward()
log_vars.update(log_vars_d)
# fake
fake_d_pred = self.discriminator(fake_g_output.detach())
loss_d_fake = self.gan_loss(
fake_d_pred, target_is_real=False, is_disc=True)
loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
loss_d.backward()
log_vars.update(log_vars_d)
optimizer['discriminator'].step()
self.step_counter += 1
log_vars.pop('loss') # remove the unnecessary 'loss'
outputs = dict(
log_vars=log_vars,
num_samples=len(gt.data),
results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))
return outputs
一开始,图像输出都在词典 data_batch
里。函数先把低清图 lq
和高清的真值 gt
从词典里取出。
# data
lq = data_batch['lq']
gt = data_batch['gt']
之后,函数计算了G(Ilq)G(I^{lq}),为后续 loss 的计算做准备。
# generator
fake_g_output = self.generator(lq)
接下来,是优化生成器 self.generator
的逻辑。这里面有一些函数调用,我们可以不管它们的实现,大概理解整段代码的意思就行了。
losses = dict()
log_vars = dict()
# no updates to discriminator parameters.
set_requires_grad(self.discriminator, False)
if (self.step_counter % self.disc_steps == 0
and self.step_counter >= self.disc_init_steps):
if self.pixel_loss:
losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
if self.perceptual_loss:
loss_percep, loss_style = self.perceptual_loss(
fake_g_output, gt)
if loss_percep is not None:
losses['loss_perceptual'] = loss_percep
if loss_style is not None:
losses['loss_style'] = loss_style
# gan loss for generator
fake_g_pred = self.discriminator(fake_g_output)
losses['loss_gan'] = self.gan_loss(
fake_g_pred, target_is_real=True, is_disc=False)
# parse loss
loss_g, log_vars_g = self.parse_losses(losses)
log_vars.update(log_vars_g)
# optimize
optimizer['generator'].zero_grad()
loss_g.backward()
optimizer['generator'].step()
为了只训练生成器,要用下面的代码关闭判别器的训练。
# no updates to discriminator parameters.
set_requires_grad(self.discriminator, False)
正文说过,训练 GAN 时一般要先训好判别器,且训练判别器多于训练生成器。因此,下面的 if 语句可以让判别器训练了 self.disc_init_steps
步后,每训练 self.disc_steps
步判别器再训练一步生成器。
if (self.step_counter % self.disc_steps == 0
and self.step_counter >= self.disc_init_steps):
if 语句块里分别计算了逐像素误差(比如均方误差和 L1 误差)、感知误差、GAN 误差。虽然 SRGAN 完全抛弃了逐像素误差,但实际训练时我们还是可以按一定比例加上这个误差。这些误差最后会用于训练生成器。
if self.pixel_loss:
losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
if self.perceptual_loss:
loss_percep, loss_style = self.perceptual_loss(
fake_g_output, gt)
if loss_percep is not None:
losses['loss_perceptual'] = loss_percep
if loss_style is not None:
losses['loss_style'] = loss_style
# gan loss for generator
fake_g_pred = self.discriminator(fake_g_output)
losses['loss_gan'] = self.gan_loss(
fake_g_pred, target_is_real=True, is_disc=False)
# parse loss
loss_g, log_vars_g = self.parse_losses(losses)
log_vars.update(log_vars_g)
# optimize
optimizer['generator'].zero_grad()
loss_g.backward()
optimizer['generator'].step()
训练完生成器后,要训练判别器。和生成器的误差计算方法类似,判别器的训练代码如下:
# discriminator
set_requires_grad(self.discriminator, True)
# real
real_d_pred = self.discriminator(gt)
loss_d_real = self.gan_loss(
real_d_pred, target_is_real=True, is_disc=True)
loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
optimizer['discriminator'].zero_grad()
loss_d.backward()
log_vars.update(log_vars_d)
# fake
fake_d_pred = self.discriminator(fake_g_output.detach())
loss_d_fake = self.gan_loss(
fake_d_pred, target_is_real=False, is_disc=True)
loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
loss_d.backward()
log_vars.update(log_vars_d)
optimizer['discriminator'].step()
这段代码有两个重点:
-
在训练判别器时,要用
set_requires_grad(self.discriminator, True)
开启判别器的梯度计算。 -
fake_d_pred = self.discriminator(fake_g_output.detach())
这一行的detach()
很关键,detach()
可以中断某张量的梯度跟踪。fake_g_output
是由生成器算出来的,如果不把这个张量的梯度跟踪切断掉,在优化判别器时生成器的参数也会跟着优化。
函数的最后部分是一些和 MMEditing 其他代码逻辑的交互,和 SRGAN 本身没什么关联。
self.step_counter += 1
log_vars.pop('loss') # remove the unnecessary 'loss'
outputs = dict(
log_vars=log_vars,
num_samples=len(gt.data),
results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))
return outputs
只要理解了本文的误差计算公式,再看懂了这段代码是如何训练判别器和生成器的,就算是完全理解了 SRGAN 的核心思想了。
参考资料
[1] (SRGAN): Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
[2] (GAN): Generative Adversarial Nets
[3] (Perceptual Loss):Perceptual Losses for Real-Time Style Transfer and Super-Resolution
[4] (ESRGAN): ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
作者:OpenMMLab
文章来源:稀土掘金
推荐阅读
更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。