SRGAN-超分辨率图像复原

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/l7H9JA4/article/details/88360991

640?wx_fmt=gif 

  作者: 陈  扬

编辑:黄俊嘉


github:https://github.com/OUCMachineLearning/OUCML/blob/master/GAN/srgan_celebA/srgan.py

arxiv:https://arxiv.org/abs/1609.04802


我的研究方向:GAN

大家好,我是中国海洋大学的陈扬。在遥远的九月份,我开始做了keras的系列教程,现在我主要的研究方向转到了生成对抗网络,生成对抗网络的代码实现和训练机制比分类模型都要复杂和难入门.之前一段时间时间一直在帮璇姐跑cvpr的实验代码,做了蛮多的对比实验,其中我就发现了,keras的代码实现和可阅读性很好,搭生成对抗网络网络GAN就好像搭乐高积木一样有趣哦。不只是demo哦,我还会在接下来的一系列 keras教程中教你搭建Alexnet,Vggnet,Resnet,DCGAN,ACGAN,CGAN,SRGAN,等等实际的模型并且教你如何在GPU服务器上运行。


前  言

上个星期发了一篇有关GAN入门的文章,同学们都觉得挺有趣的,上一次我写了如何理解最基础的GAN的原理,今天我给大家带来的是如何运用强大的GAN做一些好玩的应用.

超分辨率复原一直是计算机视觉领域一个十分热门的研究方向,在商业上也有着很大的用武之地,随着2014年goodflew那篇惊世骇俗的GAN发表出来,GAN伴随着CNN一起,可谓是乘风破浪,衍生出来琳琅满目的各种应用.

640?wx_fmt=png

简单的来说,就给定一个低分辨率图片作为噪声z的输入,通过生成器的变换把噪声的概率分布空间尽可能的去拟合真实数据的分布空间.



01

基本框架

640?wx_fmt=jpeg

在这里,我们把生成器看的目标看成是要以次充好,判别器的目标是要明辨真假.

我们可以的看到,在生成器的前6层网络中,我们运用了残差块,为什么要用残差块呢?

因为我们可以从上图看出来,当损失函数从判别器开始反向传播会生成器的时候,实际上进过来很多层,我们知道越深的网络隐藏参数越多,在反向传播的过程中也越容易梯度弥散.而且残差连接的方法,就有效的保证了我们梯度信息能够有效的传递而增强生成对抗网络的鲁棒性.(事实上沃瑟斯坦loss也可以增强GAN训练的鲁棒性,以后会写)


02

celebA

再来聊聊今天用的数据集,这是Celeb-A,里面有大量的带标注信息的明星人脸.在目前很多的GAN的应用中,都是用CelebA作为基础的数据集,这个数据集大概在1.2G左右,可以在kaggle上下载.

640?wx_fmt=png

  • 浏览数据集

[https://www.kaggle.com/jessicali9530/celeba-dataset]

A popular component of computer vision and deep learning revolves around identifying faces for various applications from logging into your phone with your face or searching through surveillance images for a particular suspect. This dataset is great for training and testing models for face detection, particularly for recognising facial attributes such as finding people with brown hair, are smiling, or wearing glasses. Images cover large pose variations, background clutter, diverse people, supported by a large quantity of images and rich annotations. This data was originally collected by researchers at MMLAB, The Chinese University of Hong Kong (specific reference in Acknowledgment section).


03

 Overall

202,599 number of face images of various celebrities

10,177 unique identities, but names of identities are not given

40 binary attribute annotations per image

5 landmark locations

04

Super-Resolution IMAGE

简单点说,就是给你一张模糊的图片,让你复原一张高清的图片.

640?wx_fmt=png

05

我们如何用生成对抗网络来做呢?

这个时候,我们可以把LRimg看成是一个噪声z的输入,G生成的是一个FAKE-HRimg,我们让D分辨fake-HRimg and original HRimg.


06

定义一个目标函数

Our ultimate goal is to train a generating function G that estimates for a given LR input image its corresponding HR counterpart. To achieve this, we train a generator network as a feed-forward CNN GθG parametrized by θG. Here θG = {W1:L ; b1:L } denotes the weights and biases of a L-layer deep network and is obtained by optimizing a SR-specific

loss function lSR. For training images IHR , n = 1, . . . , N n

withcorrespondingILR,n=1,...,N,wesolve:

640?wx_fmt=png


07

提出perceptual loss  

作者认为这更接近人的主观感受,因为使用pixel-wise的MSE使得图像变得平滑,而如果先用VGG来抓取到高级特征(feature)表示,再对feature使用MSE,可以更好的抓取不变特征。

640?wx_fmt=png

640?wx_fmt=jpeg

640?wx_fmt=jpeg

640?wx_fmt=jpeg


  • 核心公式

640?wx_fmt=jpeg

这个公式我们要分成两个部分来看,先看前半部分:

640?wx_fmt=jpeg

这个公式的意思是,先看加号前面640?wx_fmt=png,我们希望D最大,所以640?wx_fmt=png应该最大,意味着我的判别器可以很好的识别出,真实的高分辨率图像是"true",在看加号后面的640?wx_fmt=png,要让log尽可能的大,需要的是ΘD(ΘG(z))尽可能的小,意味着我们生成模型复原的图片应该尽可能的被判别模型视为"FALSE".

640?wx_fmt=jpeg

再看后半部分部分

我们应该让G尽可能的小,加号前面的式子并没有G,所以无关,在看加号后面的式子

640?wx_fmt=png

,要让ΘG尽可能地小,就要ΘD(ΘG(Z))尽可能的大,也就是说本来就一张低分辨率生成的图片,判别器却被迷惑了,以为是一张原始的高分辨率图片.这就是所谓的以次充好.


08

网络设计

640?wx_fmt=png


09

loss函数

###vgg用于提取特征
self.vgg.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])
###生成器
self.combined.compile(loss=['binary_crossentropy', 'mse'],
loss_weights=[1e-3, 1],
optimizer=optimizer)
###判别器
self.discriminator.compile(loss='mse',
optimizer=optimizer,
metrics=['accuracy'])


10

train

  • 训练判别器

d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


  • 训练生成器

image_features = self.vgg.predict(imgs_hr)
# Train the generators
g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])


11

实际结果

640?wx_fmt=png5000 batchsize


12

对比实验结果

640?wx_fmt=png


13

谢谢大家

嘻嘻嘻,大家要是喜欢这个系列的话就给我个小小的赞哦,等到我期末考考完试文件录视频来讲如何从零开始搭建生成对抗网络,emmmm现在大二学业压力确实大了,不过我们的创造会一如既往的用心做下去的,感谢你们的陪伴,也是我持续创造的动力源泉.



 640?wx_fmt=gif

END


往期回顾之作者陈扬

【1】手把手教你用keras搭建GAN

【2】手把手教你用keras--CNN网络识别cifar10

【3】手把手教你用keras--像搭乐高积木一样搭建神经网络(lenet)

【4】机器学习论文笔记(七):一种简单有效的网络结构搜索

【5】机器学习论文笔记-如何利用高效搜索算法来搜索网络的拓扑结构




机器学习算法工程师


                            一个用心的公众号

640?wx_fmt=jpeg

长按,识别,加关注

进群,学习,得帮助

你的关注,我们的热度,

我们一定给你学习最大的帮助



640?wx_fmt=png你点的每个赞,我都认真当成了喜欢


展开阅读全文

没有更多推荐了,返回首页