SRGAN 图像超分辨率重建(Keras)


前言

SRGAN 网络是用GAN网络来实现图像超分辨率重建的网络。训练完网络后。只用生成器来重建低分辨率图像。网络结构主要使用生成器(Generator)和判别器(Discriminator)。训练过程不太稳定。一般用于卫星图像,遥感图像的图像重建,人脸图像超分重建。
这里我们使用的高分辨率的数据集 (DIV2K)
数据集下载链接:链接:https://pan.baidu.com/s/1UBle5Cu74TRifcAVz14cDg 提取码:luly
github代码地址:https://github.com/jiantenggei/srgan
重制版代码仓库:https://github.com/jiantenggei/Srgan_

一、SRGAN

1.训练步骤

SRGAN 网络的训练思路如下图所示:
在这里插入图片描述

训练步骤如下:
(1) 将低分辨率输入到生成网络,生成高分辨率图像。
(2) 将高分辨率图像输入的判别网络判别真假,与0和1进行对比
(3) 将原始高分辨率图像和生成的高分辨率图像分别用VGG19 的前9层提取特征,将提取的特征计算loss。
(4). 将loss返回给生成器继续训练。
这就是SRGAN 的训练流程了。
接下来我们一一去实现上述步骤。

2.生成器

生成器网络结构如下图所示:
在这里插入图片描述
生成器主要有两部分构成,第一部分是residual block 残差块(图中红色方块),第二部分是上采样部分(图中蓝色方块)用来上采样特征图。
残差块:包含一个两个3x3的卷积 BN,PReLu
上采样:使用UpSampling2D,这里可能与原模型不同实现
生成器代码如下所示:

# 生成器中的残差块
def res_block_gen(x, kernal_size, filters, strides):
    
    gen = x
    
    x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
    x = BatchNormalization(momentum = 0.5)(x)
    # Using Parametric ReLU
    x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)
    x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
    x = BatchNormalization(momentum = 0.5)(x)
        
    x = add([gen, x])
    
    return x

#上采样样块
def up_sampling_block(x, kernal_size, filters, strides):
    x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
    x = UpSampling2D(size = 2)(x)
    x = LeakyReLU(alpha = 0.2)(x)
    
    return x
#--------------------------------------
# 亚像素卷积上采样块
# 生成器 还是用的 UpSampling2D
# 如果有需要可以自己更改
# -------------------------------------
def SubpixelConv2D(input_shape, scale=4):
    def subpixel_shape(input_shape):
        dims = [input_shape[0],input_shape[1] * scale,input_shape[2] * scale,int(input_shape[3] / (scale ** 2))]
        output_shape = tuple(dims)
        return output_shape
    
    def subpixel(x):
        return tf.compat.v1.depth_to_space(x, scale)
        
    return Lambda(subpixel, output_shape=subpixel_shape)
    
def Generator(input_shape=[128,128,3]):
    gen_input = Input(input_shape)
    x = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)
    x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)
	    
    gen_x = x
        
    # 16 个残差快
    for index in range(16)
  • 12
    点赞
  • 142
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 21
    评论
SRGAN是一种基于生成对抗网络的超分辨率模型,其核心思路是通过训练一个生成器网络和一个判别器网络,使得生成器网络能够将低分辨率图像转换为高分辨率图像,并且判别器网络能够区分生成的高分辨率图像和真实高分辨率图像之间的差异。 具体实现上,SRGAN模型包含两个部分:生成器网络和判别器网络。生成器网络由一系列卷积层、反卷积层和残差块组成,用于将低分辨率图像转换为高分辨率图像。判别器网络则由一系列卷积层和全连接层组成,用于判别生成的高分辨率图像和真实高分辨率图像之间的差异。 在训练过程中,生成器网络和判别器网络相互博弈,不断优化模型参数,直到最终生成的高分辨率图像能够和真实高分辨率图像无法区分为止。具体来说,训练过程包括以下几个步骤: 1. 准备训练数据。将高分辨率图像和对应的低分辨率图像作为训练数据,其中低分辨率图像可以通过对高分辨率图像进行下采样得到。 2. 训练判别器网络。首先,使用真实高分辨率图像和生成器网络生成的高分辨率图像作为正样本,使用下采样的高分辨率图像作为负样本,训练判别器网络,使其能够区分正样本和负样本。 3. 训练生成器网络。首先,使用生成器网络将低分辨率图像转换为高分辨率图像,然后使用判别器网络判别生成的高分辨率图像的质量。通过最小化生成的高分辨率图像和真实高分辨率图像之间的差异,以及最大化判别器网络判别生成的高分辨率图像为真实高分辨率图像的概率,来优化生成器网络的参数。 4. 微调生成器网络。在训练过程中,生成器网络可能会出现过拟合或者欠拟合的情况,需要通过微调生成器网络的参数来解决这些问题。 SRGAN模型的训练需要大量的高分辨率图像和低分辨率图像作为训练数据,并且需要在GPU上进行训练,因此需要一定的计算资源和时间。但是,SRGAN模型能够生成非常逼真的高分辨率图像,对于一些对图像质量要求较高的应用场景具有很大的帮助。
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

__不想写代码__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值