SRGAN的生成器
初始的SRGAN生成器
import math
import tensorflow as tf
from keras.initializers import random_normal
from keras import layers
from keras.applications import VGG19
from keras.models import Model
def SubpixelConv2D(scale=4): #定义一个上采样函数,将小特征图进行拼接成大的
def subpixel_shape(input_shape):
dims = [input_shape[0],
None if input_shape[1] is None else input_shape[1] * scale,
None if input_shape[2] is None else input_shape[2] * scale,
int(input_shape[3] / (scale ** 2))] #上采样,将特征图的长宽放大scale倍
output_shape = tuple(dims) #将可迭代系列(如列表)转换为元组
return output_shape #返回生成的特征图的shape
def subpixel(x):
return tf.depth_to_space(x, scale) #该函数主要用于4D tensor,因此,数据的格式默认为 ‘NHWC’ , 函数将Channel轴数据变换到Hight 和 Width轴。输入的形状是:[batch, height, width, depth],输出的形状为:[batch, height*scale, width*scale, depth/(scale**2)];
#将一个较多通道的特征变成较少通道的特征
return layers.Lambda(subpixel, output_shape=subpixel_shape) #匿名函数层,function=subpixel,制定输出大小为subpixel_shape,为啥指定呢,可能是为了防止出差错
def residual_block(inputs, filters): #定义残差块
x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs) #进行卷积,卷积核大小为3*3,步长为1,边缘填充,使用标准正态分布初始化,标准差为0.02
x = layers.BatchNormalization(momentum=0.5)(x) #进行块的标准化,momentum=0.5表示之前值的权重,执行的是moving_average_value*momentum+value*(1-momentum)
x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x) #带参数的ReLU,会自动学习参数
x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)#卷积
x = layers.BatchNormalization(momentum=0.5)(x)
x = layers.Add()([x, inputs]) #add对张量执行求和运算,而concatenate对张量进行串联运算
return x
def deconv2d(inputs): #上采样
x = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs) #卷积
x = SubpixelConv2D(scale=2)(x) #特征上采样
x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)
return x
def build_generator(lr_shape, scale_factor, num_residual=16):
#-----------------------------------#
# 获得进行上采用的次数
#-----------------------------------#
upsample_block_num = int(math.log(scale_factor, 2)) #上采样模块的数目,math.log(scale_factor, 2)表示log2(scale_factor);
img_lr = layers.Input(shape=lr_shape) #输入低分辨图片
#--------------------------------------------------------#
# 第一部分,低分辨率图像进入后会经过一个卷积+PRELU函数
#--------------------------------------------------------#
x = layers.Conv2D(64, kernel_size=9, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(img_lr)
x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)
short_cut = x #short_cut作为最后的残差输入,保证原始信息的充分
#-------------------------------------------------------------#
# 第二部分,经过num_residual个残差网络结构。
# 每个残差网络内部包含两个卷积+标准化+PRELU,还有一个残差边。
#-------------------------------------------------------------#
for _ in range(num_residual): #进行16个残差网络
x = residual_block(x, 64)
x = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)
x = layers.BatchNormalization(momentum=0.5)(x)
x = layers.Add()([x, short_cut])
#-------------------------------------------------------------#
# 第三部分,上采样部分,将长宽进行放大scale倍。
# n次上采样后,变为原来的2**n倍,实现提高分辨率。
#-------------------------------------------------------------#
for _ in range(upsample_block_num):
x = deconv2d(x)
gen_hr = layers.Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(x) #卷积
return Model(img_lr, gen_hr) #模块生成
def build_vgg():
# 建立VGG19模型,去掉全连接层,权重值应用已经训练好的"imagenet",用于获取图像特征
vgg = VGG19(False, weights="imagenet")
vgg.outputs = [vgg.layers[-2].output] #去掉一个最大池化层,其他层都保留
img = layers.Input(shape=[None,None,3])
img_features = vgg(img)
return Model(img, img_features)
if __name__ == "__main__":
model = build_generator([56,56,3]) #低分辨图像的大小,进入生成器
#model.summary() 能看到输出模型各层的参数状况 #一般需要在开头加model = Sequential() # 顺序模型