前言
提示:这里可以添加本文要记录的大概内容:
- 近几年由于GAN在图像细节生成能力上的优越性开始引起研究者们的注意,并被应用在图像去模糊上,如何提升复原图像的质量以满足实际的应用是目前研究的重点。
利用目前性能优异的DeblurGAN网络搭配数据集GOPRO训练,实现图像去模糊这一过程。
提示:以下是本篇文章正文内容,下面案例可供参考
一、GAN是什么?
- 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(GenerativeModel)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
二、Deblur GAN原理
- DeblurGAN是乌克兰天主教大学的Orest Kupyn等人提出的一种基于GAN方法进行盲运动模糊移除的方法。
- 受启发于SRGAN与CGAN的成功,将图像模糊移除视为一种特殊的Image2Image任务,DeblurGAN基于wGAN以及内容损失进行训练学习,在SSIM与视觉效果方面,它取得了SOTA性能。
- 主要贡献:
提出一种损失与框架,它在运动模糊移除方面取得了SOTA性能;
提出一种基于随机轨迹的动模糊数据制作方法;
构建一个新的数据集与评价方法(基于目标检测结果提升)。
三、代码参考
关于DeblurGAN的实现代码,这里给出几个参考:
[1]https://github.com/dongheehand/DeblurGAN-tf
[2]https://github.com/LeeDoYup/DeblurGAN-tf
[3]https://github.com/KupynOrest/DeblurGAN
四、代码实现
mian.py
import tensorflow as tf
from DeblurGAN import DeblurGAN
from mode import *
import argparse
parser = argparse.ArgumentParser()
def str2bool(v):
return v.lower() in ('true')
## Model specification
parser.add_argument("--n_feats", type=int, default=64)
parser.add_argument("--num_of_down_scale", type=int, default=2)
parser.add_argument("--gen_resblocks", type=int, default=9)
parser.add_argument("--discrim_blocks", type=int, default=3)
## Data specification
parser.add_argument("--train_Sharp_path", type=str, default="./data/train/sharp/")
parser.add_argument("--train_Blur_path", type=str, default="./data/train/blur")
parser.add_argument("--test_Sharp_path", type=str, default="./data/val/val_sharp")
parser.add_argument("--test_Blur_path", type=str, default="./data/val/val_blur")
parser.add_argument("--vgg_path", type=str, default="./vgg19.npy")
parser.add_argument("--patch_size", type=int, default=256)
parser.add_argument("--result_path", type=str, default="./result")
parser.add_argument("--model_path", type=str, default="./model")
## Optimization
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_epoch", type=int, default=200)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--decay_step", type=int, default=150)
parser.add_argument("--test_with_train", type=str2bool, default=True)
parser.add_argument("--save_test_result", type=str2bool, default=True)
## Training or test specification
parser.add_argument("--mode", type=str, default="test")
parser.add_argument("--critic_updates", type=int, default=5)
parser.add_argument("--augmentation", type=str2bool, default=False)
parser.add_argument("--load_X", type=int, default=640)
parser.add_argument("--load_Y", type=int, default=360)
parser.add_argument("--fine_tuning", type=str2bool, default=False)
parser.add_argument("--log_freq", type=int, default=1)
parser.add_argument("--model_save_freq", type=int, default=20)
parser.add_argument("--pre_trained_model", type=str, default="./model/")
parser.add_argument("--test_batch", type=int, default=5)
args = parser.parse_args()
model = DeblurGAN(args)
model.build_graph()
data_loader.py
import tensorflow as tf
import os
class dataloader():
def __init__(self, args):
self.channel = 3
self.mode = args.mode
self.patch_size = args.patch_size
self.batch_size = args.batch_size
self.train_Sharp_path = args.train_Sharp_path
self.train_Blur_path = args.train_Blur_path
self.test_Sharp_path = args.test_Sharp_path
self.test_Blur_path = args.test_Blur_path
self.test_with_train = args.test_with_train
self.test_batch = args.test_batch
self.load_X = args.load_X
self.load_Y = args.load_Y
self.augmentation = args.augmentation
def build_loader(self):
if self.mode == 'train':
tr_sharp_imgs = sorted(os.listdir(self.train_Sharp_path))
tr_blur_imgs = sorted(os.listdir(self.train_Blur_path))
tr_sharp_imgs = [os.path.join(self.train_Sharp_path, ele) for ele in tr_sharp_imgs]
tr_blur_imgs = [os.path.join(self.train_Blur_path, ele) for ele in tr_blur_imgs]
train_list = (tr_blur_imgs, tr_sharp_imgs)
self.tr_dataset = tf.data.Dataset.from_tensor_slices(train_list)
self.tr_dataset = self.tr_dataset.map(self._parse, num_parallel_calls = 4).prefetch(32)
self.tr_dataset = self.tr_dataset.map(self._resize, num_parallel_calls = 4).prefetch(32)
self.tr_dataset = self.tr_dataset.map(self._get_patch, num_parallel_calls = 4).prefetch(32)
if self.augmentation:
self.tr_dataset = self.tr_dataset.map(self._data_augmentation, num_parallel_calls = 4).prefetch(32)
self.tr_dataset = self.tr_dataset.shuffle(32)
self.tr_dataset = self.tr_dataset.repeat()
self.tr_dataset = self.tr_dataset.batch(self.batch_size)
if self.test_with_train:
val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
valid_list = (val_blur_imgs, val_sharp_imgs)
self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
self.val_dataset = self.val_dataset.batch(self.test_batch)
iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
self.next_batch = iterator.get_next()
self.init_op = {
}
self.init_op['tr_init'] = iterator.make_initializer(self.tr_dataset)
if self.test_with_train:
self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)
elif self.mode == 'test':
val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
valid_list = (val_blur_imgs, val_sharp_imgs)
self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
self.val_dataset = self.val_dataset.batch(1)
iterator = tf.data.Iterator