毕业设计 基于Gan的图像去模糊

本文介绍基于GAN的图像去模糊技术,重点探讨Deblur GAN原理和实验过程。通过利用Deblur GAN网络及GOPRO数据集进行训练,实现了图像从模糊到清晰的转换。文章提供相关代码参考,并分享了实验中遇到的问题与解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


前言

提示:这里可以添加本文要记录的大概内容:


  • 近几年由于GAN在图像细节生成能力上的优越性开始引起研究者们的注意,并被应用在图像去模糊上,如何提升复原图像的质量以满足实际的应用是目前研究的重点。
    利用目前性能优异的DeblurGAN网络搭配数据集GOPRO训练,实现图像去模糊这一过程。

提示:以下是本篇文章正文内容,下面案例可供参考

一、GAN是什么?

  • 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(GenerativeModel)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G 和D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。

二、Deblur GAN原理

  • DeblurGAN是乌克兰天主教大学的Orest Kupyn等人提出的一种基于GAN方法进行盲运动模糊移除的方法。
  • 受启发于SRGAN与CGAN的成功,将图像模糊移除视为一种特殊的Image2Image任务,DeblurGAN基于wGAN以及内容损失进行训练学习,在SSIM与视觉效果方面,它取得了SOTA性能。
  • 主要贡献:
    提出一种损失与框架,它在运动模糊移除方面取得了SOTA性能;
    提出一种基于随机轨迹的动模糊数据制作方法;
    构建一个新的数据集与评价方法(基于目标检测结果提升)。

三、代码参考

关于DeblurGAN的实现代码,这里给出几个参考:

[1]https://github.com/dongheehand/DeblurGAN-tf

[2]https://github.com/LeeDoYup/DeblurGAN-tf

[3]https://github.com/KupynOrest/DeblurGAN

四、代码实现

mian.py

import tensorflow as tf
from DeblurGAN import DeblurGAN
from mode import *
import argparse

parser = argparse.ArgumentParser()


def str2bool(v):
    return v.lower() in ('true')


## Model specification
parser.add_argument("--n_feats", type=int, default=64)
parser.add_argument("--num_of_down_scale", type=int, default=2)
parser.add_argument("--gen_resblocks", type=int, default=9)
parser.add_argument("--discrim_blocks", type=int, default=3)

## Data specification
parser.add_argument("--train_Sharp_path", type=str, default="./data/train/sharp/")
parser.add_argument("--train_Blur_path", type=str, default="./data/train/blur")
parser.add_argument("--test_Sharp_path", type=str, default="./data/val/val_sharp")
parser.add_argument("--test_Blur_path", type=str, default="./data/val/val_blur")
parser.add_argument("--vgg_path", type=str, default="./vgg19.npy")
parser.add_argument("--patch_size", type=int, default=256)
parser.add_argument("--result_path", type=str, default="./result")
parser.add_argument("--model_path", type=str, default="./model")

## Optimization
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_epoch", type=int, default=200)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--decay_step", type=int, default=150)
parser.add_argument("--test_with_train", type=str2bool, default=True)
parser.add_argument("--save_test_result", type=str2bool, default=True)

## Training or test specification
parser.add_argument("--mode", type=str, default="test")
parser.add_argument("--critic_updates", type=int, default=5)
parser.add_argument("--augmentation", type=str2bool, default=False)
parser.add_argument("--load_X", type=int, default=640)
parser.add_argument("--load_Y", type=int, default=360)
parser.add_argument("--fine_tuning", type=str2bool, default=False)
parser.add_argument("--log_freq", type=int, default=1)
parser.add_argument("--model_save_freq", type=int, default=20)
parser.add_argument("--pre_trained_model", type=str, default="./model/")
parser.add_argument("--test_batch", type=int, default=5)
args = parser.parse_args()

model = DeblurGAN(args)
model.build_graph()

data_loader.py

import tensorflow as tf
import os
 
class dataloader():
    
    def __init__(self, args):
 
        self.channel = 3
 
        self.mode = args.mode
        self.patch_size = args.patch_size
        self.batch_size = args.batch_size
        self.train_Sharp_path = args.train_Sharp_path
        self.train_Blur_path = args.train_Blur_path
        self.test_Sharp_path = args.test_Sharp_path
        self.test_Blur_path = args.test_Blur_path
        self.test_with_train = args.test_with_train
        self.test_batch = args.test_batch
        self.load_X = args.load_X
        self.load_Y = args.load_Y
        self.augmentation = args.augmentation
        
    def build_loader(self):
        
        if self.mode == 'train':
        
            tr_sharp_imgs = sorted(os.listdir(self.train_Sharp_path))
            tr_blur_imgs = sorted(os.listdir(self.train_Blur_path))
            tr_sharp_imgs = [os.path.join(self.train_Sharp_path, ele) for ele in tr_sharp_imgs]
            tr_blur_imgs = [os.path.join(self.train_Blur_path, ele) for ele in tr_blur_imgs]
            train_list = (tr_blur_imgs, tr_sharp_imgs)
            
            self.tr_dataset = tf.data.Dataset.from_tensor_slices(train_list)
            self.tr_dataset = self.tr_dataset.map(self._parse, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._resize, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.map(self._get_patch, num_parallel_calls = 4).prefetch(32)
            if self.augmentation:
                self.tr_dataset = self.tr_dataset.map(self._data_augmentation, num_parallel_calls = 4).prefetch(32)
            self.tr_dataset = self.tr_dataset.shuffle(32)
            self.tr_dataset = self.tr_dataset.repeat()
            self.tr_dataset = self.tr_dataset.batch(self.batch_size)
            
            if self.test_with_train:
            
                val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
                val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
                val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
                val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
                valid_list = (val_blur_imgs, val_sharp_imgs)
 
                self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
                self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
                self.val_dataset = self.val_dataset.batch(self.test_batch)
 
            iterator = tf.data.Iterator.from_structure(self.tr_dataset.output_types, self.tr_dataset.output_shapes)
            self.next_batch = iterator.get_next()
            self.init_op = {
   }
            self.init_op['tr_init'] = iterator.make_initializer(self.tr_dataset)
            
            if self.test_with_train:
                self.init_op['val_init'] = iterator.make_initializer(self.val_dataset)
            
        elif self.mode == 'test':
            
            val_sharp_imgs = sorted(os.listdir(self.test_Sharp_path))
            val_blur_imgs = sorted(os.listdir(self.test_Blur_path))
            val_sharp_imgs = [os.path.join(self.test_Sharp_path, ele) for ele in val_sharp_imgs]
            val_blur_imgs = [os.path.join(self.test_Blur_path, ele) for ele in val_blur_imgs]
            valid_list = (val_blur_imgs, val_sharp_imgs)
            
            self.val_dataset = tf.data.Dataset.from_tensor_slices(valid_list)
            self.val_dataset = self.val_dataset.map(self._parse, num_parallel_calls=4).prefetch(32)
            self.val_dataset = self.val_dataset.batch(1)
            
            iterator = tf.data.Iterator
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值