点击上方↑↑↑“OpenCV学堂”关注我
来源:公众号 机器之心 授权转载
2014 年,Ian Goodfellow 提出了生成对抗网络(GAN),今天,GAN 已经成为深度学习最热门的方向之一。本文将重点介绍如何利用 Keras 将 GAN 应用于图像去模糊(image deblurring)任务当中。
Keras 代码地址:https://github.com/RaphaelMeudec/deblur-gan
此外,请查阅 DeblurGAN 的原始论文(https://arxiv.org/pdf/1711.07064.pdf)及其 Pytorch 版本实现:https://github.com/KupynOrest/DeblurGAN/。
生成对抗网络简介
在生成对抗网络中,有两个网络互相进行训练。生成器通过生成逼真的虚假输入来误导判别器,而判别器会分辨输入是真实的还是人造的。
GAN 训练流程
训练过程中有三个关键步骤:
使用生成器根据噪声创造虚假输入;
利用真实输入和虚假输入训练判别器;
训练整个模型:该模型是判别器和生成器连接所构建的。
请注意,判别器的权重在第三步中被冻结。
对两个网络进行连接的原因是不存在单独对生成器输出的反馈。我们唯一的衡量标准是判别器是否能接受生成的样本。
以上,我们简要介绍了 GAN 的架构。如果你觉得不够详尽,可以参考这里
生成对抗网络详解与代码演示
数据
Ian Goodfellow 首先应用 GAN 模型生成 MNIST 数据。而在本教程中,我们将生成对抗网络应用于图像去模糊。因此,生成器的输入不是噪声,而是模糊的图像。
我们采用的数据集是 GOPRO 数据集。该数据集包含来自多个街景的人工模糊图像。根据场景的不同,该数据集在不同子文件夹中分类。
你可以下载简单版:https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view
或完整版:https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view
我们首先将图像分配到两个文件夹 A(模糊)B(清晰)中。这种 A&B 的架构对应于原始的 pix2pix 论文。为此我创建了一个自定义的脚本在 github 中执行这个任务,请按照 README 的说明去使用它:
https://github.com/RaphaelMeudec/deblur-gan/blob/master/organize_gopro_dataset.py
模型
训练过程保持不变。首先,让我们看看神经网络的架构吧!
生成器
该生成器旨在重现清晰的图像。该网络基于 ResNet 模块,它不断地追踪关于原始模糊图像的演变。本文同样使用了一个基于 UNet 的版本,但我还没有实现这个版本。这两种模块应该都适合图像去模糊。
DeblurGAN 生成器网络架构,源论文《DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks》。
其核心是应用于原始图像上采样的 9 个 ResNet 模块。让我们来看看 Keras 上的代码实现!
from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout
def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
"""
Instanciate a Keras Resnet Block using sequential API.
:param input: Input tensor
:param filters: Number of filters to use
:param kernel_size: Shape of the kernel for the convolution