好像还挺好玩的GAN重制版2——Keras搭建SRGAN平台进行图片超分辨率提升

学习前言

我又死了我又死了我又死了!
在这里插入图片描述

源码下载地址

https://github.com/bubbliiiing/srgan-keras

喜欢的可以点个star噢。

网络构建

一、什么是SRGAN

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。
在这里插入图片描述
该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感

二、生成网络的构建

在这里插入图片描述
生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。

SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数
2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升

前两个部分用于特征提取,第三部分用于提高分辨率。

def residual_block(inputs, filters):
    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)

    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.Add()([x, inputs])
    return x

def deconv2d(inputs):
    x = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = SubpixelConv2D(scale=2)(x)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)
    return x

def build_generator(lr_shape, scale_factor, num_residual=16):
    #-----------------------------------#
    #   获得进行上采用的次数
    #-----------------------------------#
    upsample_block_num = int(math.log(scale_factor, 2))
    img_lr = layers.Input(shape=lr_shape)

    #--------------------------------------------------------#
    #   第一部分,低分辨率图像进入后会经过一个卷积+PRELU函数
    #--------------------------------------------------------#
    x = layers.Conv2D(64, kernel_size=9, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(img_lr)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)

    short_cut = x
    #-------------------------------------------------------------#
    #   第二部分,经过num_residual个残差网络结构。
    #   每个残差网络内部包含两个卷积+标准化+PRELU,还有一个残差边。
    #-------------------------------------------------------------#
    for _ in range(num_residual):
        x = residual_block(x, 64)

    x = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.Add()([x, short_cut])

    #-------------------------------------------------------------#
    #   第三部分,上采样部分,将长宽进行放大。
    #   两次上采样后,变为原来的4倍,实现提高分辨率。
    #-------------------------------------------------------------#
    for _ in range(upsample_block_num):
        x = deconv2d(x)

    gen_hr = layers.Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)

    return Model(img_lr, gen_hr)

三、判别网络的构建

在这里插入图片描述
判别网络的构成如上图所示:

SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

def d_block(inputs, filters, strides=1):
    x = layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    return x

def build_discriminator(hr_shape):
    inputs = layers.Input(shape=hr_shape)

    x = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)
    x = layers.LeakyReLU(alpha=0.2)(x)

    x = d_block(x, 64, strides=2)
    x = d_block(x, 128)
    x = d_block(x, 128, strides=2)
    x = d_block(x, 256)
    x = d_block(x, 256, strides=2)
    x = d_block(x, 512)
    x = d_block(x, 512, strides=2)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1024, kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    validity = layers.Dense(1, activation='sigmoid', kernel_initializer = random_normal(stddev=0.02))(x)
    return Model(inputs, validity)

训练思路

SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。

一、判别器的训练

训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签

因此判别器的训练步骤如下:

1、随机选取batch_size个真实高分辨率图片。
2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

二、生成器的训练

训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。
2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss

在这里插入图片描述

利用SRGAN生成图片

SRGAN的库整体结构如下:
在这里插入图片描述

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。
在这里插入图片描述

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。
在这里插入图片描述

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。
在这里插入图片描述
训练过程中,可在results文件夹内查看训练效果:
在这里插入图片描述

  • 22
    点赞
  • 134
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 74
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 74
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bubbliiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值