前言
SRGAN 网络是用GAN网络来实现图像超分辨率重建的网络。训练完网络后。只用生成器来重建低分辨率图像。网络结构主要使用生成器(Generator)和判别器(Discriminator)。训练过程不太稳定。一般用于卫星图像,遥感图像的图像重建,人脸图像超分重建。
这里我们使用的高分辨率的数据集 (DIV2K)
数据集下载链接:链接:https://pan.baidu.com/s/1UBle5Cu74TRifcAVz14cDg 提取码:luly
github代码地址:https://github.com/jiantenggei/srgan
重制版代码仓库:https://github.com/jiantenggei/Srgan_
一、SRGAN
1.训练步骤
SRGAN 网络的训练思路如下图所示:
训练步骤如下:
(1) 将低分辨率输入到生成网络,生成高分辨率图像。
(2) 将高分辨率图像输入的判别网络判别真假,与0和1进行对比
(3) 将原始高分辨率图像和生成的高分辨率图像分别用VGG19 的前9层提取特征,将提取的特征计算loss。
(4). 将loss返回给生成器继续训练。
这就是SRGAN 的训练流程了。 接下来我们一一去实现上述步骤。
2.生成器
生成器网络结构如下图所示:
生成器主要有两部分构成,第一部分是residual block 残差块(图中红色方块),第二部分是上采样部分(图中蓝色方块)用来上采样特征图。
残差块:包含一个两个3x3的卷积 BN,PReLu
上采样:使用UpSampling2D,这里可能与原模型不同实现
生成器代码如下所示:
# 生成器中的残差块
def res_block_gen(x, kernal_size, filters, strides):
gen = x
x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
x = BatchNormalization(momentum = 0.5)(x)
# Using Parametric ReLU
x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)
x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
x = BatchNormalization(momentum = 0.5)(x)
x = add([gen, x])
return x
#上采样样块
def up_sampling_block(x, kernal_size, filters, strides):
x = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(x)
x = UpSampling2D(size = 2)(x)
x = LeakyReLU(alpha = 0.2)(x)
return x
#--------------------------------------
# 亚像素卷积上采样块
# 生成器 还是用的 UpSampling2D
# 如果有需要可以自己更改
# -------------------------------------
def SubpixelConv2D(input_shape, scale=4):
def subpixel_shape(input_shape):
dims = [input_shape[0],input_shape[1] * scale,input_shape[2] * scale,int(input_shape[3] / (scale ** 2))]
output_shape = tuple(dims)
return output_shape
def subpixel(x):
return tf.compat.v1.depth_to_space(x, scale)
return Lambda(subpixel, output_shape=subpixel_shape)
def Generator(input_shape=[128,128,3]):
gen_input = Input(input_shape)
x = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)
x = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(x)
gen_x = x
# 16 个残差快
for index in range(16)