SRGAN的生成器

SRGAN的生成器

初始的SRGAN生成器

import math
import tensorflow as tf
from keras.initializers import random_normal
from keras import layers
from keras.applications import VGG19
from keras.models import Model

def SubpixelConv2D(scale=4):    #定义一个上采样函数,将小特征图进行拼接成大的
    def subpixel_shape(input_shape):
        dims = [input_shape[0],
                None if input_shape[1] is None else input_shape[1] * scale,
                None if input_shape[2] is None else input_shape[2] * scale,
                int(input_shape[3] / (scale ** 2))]   #上采样,将特征图的长宽放大scale倍
        output_shape = tuple(dims)   #将可迭代系列(如列表)转换为元组
        return output_shape         #返回生成的特征图的shape

    def subpixel(x):
        return tf.depth_to_space(x, scale)    #该函数主要用于4D tensor,因此,数据的格式默认为 ‘NHWC’ , 函数将Channel轴数据变换到Hight 和 Width轴。输入的形状是:[batch, height, width, depth],输出的形状为:[batch, height*scale, width*scale, depth/(scale**2)];
                                              #将一个较多通道的特征变成较少通道的特征

    return layers.Lambda(subpixel, output_shape=subpixel_shape)  #匿名函数层,function=subpixel,制定输出大小为subpixel_shape,为啥指定呢,可能是为了防止出差错


def residual_block(inputs, filters):    #定义残差块     
    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs)             #进行卷积,卷积核大小为3*3,步长为1,边缘填充,使用标准正态分布初始化,标准差为0.02
    x = layers.BatchNormalization(momentum=0.5)(x)    #进行块的标准化,momentum=0.5表示之前值的权重,执行的是moving_average_value*momentum+value*(1-momentum)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)  #带参数的ReLU,会自动学习参数

    x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)#卷积
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.Add()([x, inputs])     #add对张量执行求和运算,而concatenate对张量进行串联运算
    return x

def deconv2d(inputs):         #上采样
    x = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(inputs) #卷积
    x = SubpixelConv2D(scale=2)(x)    #特征上采样
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)
    return x

def build_generator(lr_shape, scale_factor, num_residual=16):
    #-----------------------------------#
    #   获得进行上采用的次数
    #-----------------------------------#
    upsample_block_num = int(math.log(scale_factor, 2))  #上采样模块的数目,math.log(scale_factor, 2)表示log2(scale_factor);
    img_lr = layers.Input(shape=lr_shape)    #输入低分辨图片

    #--------------------------------------------------------#
    #   第一部分,低分辨率图像进入后会经过一个卷积+PRELU函数
    #--------------------------------------------------------#
    x = layers.Conv2D(64, kernel_size=9, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(img_lr)
    x = layers.advanced_activations.PReLU(shared_axes=[1,2])(x)

    short_cut = x             #short_cut作为最后的残差输入,保证原始信息的充分
    #-------------------------------------------------------------#
    #   第二部分,经过num_residual个残差网络结构。
    #   每个残差网络内部包含两个卷积+标准化+PRELU,还有一个残差边。
    #-------------------------------------------------------------#
    for _ in range(num_residual):     #进行16个残差网络
        x = residual_block(x, 64)

    x = layers.Conv2D(64, kernel_size=3, strides=1, padding='same', kernel_initializer = random_normal(stddev=0.02))(x)
    x = layers.BatchNormalization(momentum=0.5)(x)
    x = layers.Add()([x, short_cut])    

    #-------------------------------------------------------------#
    #   第三部分,上采样部分,将长宽进行放大scale倍。
    #   n次上采样后,变为原来的2**n倍,实现提高分辨率。
    #-------------------------------------------------------------#
    for _ in range(upsample_block_num):
        x = deconv2d(x)  

    gen_hr = layers.Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)  #卷积

    return Model(img_lr, gen_hr)      #模块生成

def build_vgg():
    # 建立VGG19模型,去掉全连接层,权重值应用已经训练好的"imagenet",用于获取图像特征
    vgg = VGG19(False, weights="imagenet")
    vgg.outputs = [vgg.layers[-2].output]   #去掉一个最大池化层,其他层都保留

    img = layers.Input(shape=[None,None,3])
    img_features = vgg(img)

    return Model(img, img_features)    


if __name__ == "__main__":
    model = build_generator([56,56,3])   #低分辨图像的大小,进入生成器
    #model.summary() 能看到输出模型各层的参数状况       #一般需要在开头加model = Sequential() # 顺序模型

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值