最近在学习图像超分辨的任务,看了一些论文,从较古老的SRGAN到最近IEEE上的新论文都进行了简单的阅读,感觉要学习的东西还是挺多呀!!
这篇文章主要对基于GAN的图像超分辨领域的开山之作SRGAN进行简单的梳理,加深和扩宽一下自己的思路。
目前SRGAN的代码已经开源,下面给出链接(至少2023年可以用)
论文地址:1609.04802.pdf (arxiv.org)
实现图像超分辨其实思路十分清晰。
首先输入低分辨率图像lr,通过cnn产生一个生成的假的高分辨率图像fr,然后用这个fr和你自己真正的高分辨率图像hr进行对比,得到你假的fr和真正hr的差距(其实就是损失函数),然后用这个差距来改善我们的cnn。什么时候你的假的fr和真正hr都几乎没有差距了,那就说明训练成功了。
那么从上面这段话中,我们很容易得到一个基本的训练要求,那就是你得有巨多巨多巨多成对的高分辨率图像和低分辨率图像,这种训练也是所谓的监督学习。
其实关注GAN网络,我个人认为主要就关注三个部分:1.生成器网络 2.判别器网络 3.损失函数,任何网络只要搞清楚这三部分其实都会迎刃而解。
SRGAN生成器和判别器网络都较为基础,利用残差、卷积、BN层搭建。
1.生成器网络及损失函数
网络的生成器如上图所示,核心是中间的五个残差模块,每一个残差包括卷积、BN、PReLU、卷积、BN组成,通过跳层实现残差。
图像的上采样通过PixelShuffle,如下图所示。PixelShuffle为目前图像领域最常用的上采样方法,主要利用层数换分辨率,将多层信息整合到同一层中,所以若需要n倍放大则需要n**2中间层。
生成器损失函数由对抗损失adversarial_loss 、图像MSE损失image_loss 、内容损失perception_loss 和平滑损失tv_loss四部分组成,
loss = image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
adversarial_loss为生成图像过判别器后的值,由于希望过判别器后的值越大越好,所以adversarial_loss=torch.mean(1 - out_labels),即out_labels越大生成器模型就越接近成功了。image_loss、tv_loss为MES损失和平滑,可直接利用python库继续计算。内容损失perception_loss常利用VGG网络计算生成图像和真实图像在高层的像素距离,内容损失也可用RESNET进行计算。
2.判别器及损失函数
网络的判别器为多个卷积层组成,最后过sigmoid得到归一化的值进行判别。
损失函数 d_loss = 1 - real_loss_out + fake_loss_out,判别器对真实图像判别为高,对生成图像判别为低。
整个代码训练后的效果不算特别好,但是对比以往的各种插值提高图像分辨率的方法,整体的效果还是提升了很多。下图是我的一些测试图像演示。
这就是SRGAN的简单介绍了,之后还会继续跟新前沿的图像超分辨算法介绍和感悟,一起学习进步。
预训练模型我放这里啦(CSDN下载应该不需要积分啥的吧,我反正选的不要,如果大家下载有问题直接找我哈):【免费】Srgan训练模型,用于SRGAN网络资源-CSDN文库