SRGAN_tensorflow_code

本文详细介绍了在TensorFlow中实现SRGAN的过程,包括SRGAN网络的构成、训练步骤及VGG网络的引入,旨在以通俗易懂的方式解析超分辨图像生成的代码实现。文章分为test mode、inference mode和train mode三个部分,分别对应模型测试、任意图像超分辨和模型训练。在训练模式下,首先使用SRResnet训练生成器,然后训练SRGAN网络,并通过VGG网络进一步优化细节。
摘要由CSDN通过智能技术生成

SRGAN_tensorflow包含两个网络,分别是SRGAN网络和VGG网络。其中,SRGAN网络由生成器网络generator和判决器网络discriminator组成。根据原始论文(Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network)框架,其训练过程如下;

  1. 在SRResnet task下训练SRResnet网络,得到最初的generator网络。输入是HR降采样后的低分LR图片,输出是低分LR经过超分后产生得超分图片SR,HR作为ground truth和SR一起衡量content loss。循环100万次;
  2. 保留1中generator网络所有参数,随机初始化discriminator网络参数,训练SRGAN网络。使用MSE衡量生成器内容损失content loss,使用交叉验证衡量生成器对抗损失adversarial loss(由生成图片输入判决器判决结果决定),使用交叉验证衡量判决器的discriminator loss(包含discrim_fake_loss和discrim_real_loss,前者由生成图片经过判决器的输出结果决定,后者由真实图片经过判决器的输出结果决定)。循环50万次;
  3. 保留2中SRGAN网络所有参数,引入VGG网络。使用VGG网络分别提取真实图片HR和生成图片SR的特征,计算两个特征差作为损失函数,更好的保留图像细节信息。循环20万次

本文尽可能以通俗易懂的方式介绍下超分辨TensorFlow实现的代码。从main脚本文件开始,首先要需要输入命令行参数作为整个srgan的超参数。条件语句用来检测某些关键参数如输入输出路径是否传入或存在。在这些准备好之后,程序就要开始它的核心工作了,这里核心工作共有三个,test mode(利用训练好的generator模型在LR上测试),inference mode(利用训练好的generator在任意数据集上测试),以及train mode(训练模式包含两个task,SRResnet和SRGAN ,每个task下的perceptual_mode 有三种模式可选,VGG54,VGG22,MSE)。在train mode模式下,上述过程1工作时,task=SRResnet,perceptual_mode=MSE;过程2工作时,task=SRGAN,perceptual_mode=MSE;过程3工作时,task=SRGAN,perceptual_mode=VGG54/22。每个模式的详细功能分别介绍如下。

1.test mode。

任务是把低分辨数据集放到训练好的模型上测试并保存结果。进入该模式第一步就是调用mode.py中的test_data_loader(FLAGS)函数读入数据,os.listdir() 方法用于返回路径FLAGS.input_dir_LR下包含的所有文件的名字的列表。将文件名列表分别和各自对应的具体路径拼接在一起,得到两个包含文件名路径的完整地址列表image_list_LR和image_list_HR。接着,利用for循环调用preprocess_test(name, mode)函数读入image_list_LR和image_list_HR对应的图像列表,得到image_LR(map(0,1))和image_HR(map(-1,1))。最终返回一个大的列表test_data,包含两个完整路径和两个完整图像列表。同时定义四个占位符,后面使用时直接给占位符赋值。

    inputs_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='inputs_raw')
    targets_raw = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='targets_raw')
    path_LR = tf.placeholder(tf.string, shape=[], name='path_LR')
    path_HR = tf.placeholder(tf.string, shape=[], name='path_HR')

接下来根据工作模式将数据送入generator,这里的inputs_raw占位符在session中会被赋以真实值test_data.inputs

    with tf.variable_scope('generator'):
        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            gen_output = generator(inputs_raw, 3, reuse=False, FLAGS=FLAGS)
        else:
            raise NotImplementedError('Unknown task!!')

generator是实现超分网络的模型,主要由多个残差块和卷积层组成,详细信息如下:

# Definition of the generator
def generator(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
    # Check the flag
    if FLAGS is None:
        raise  ValueError('No FLAGS is provided for generator')

    # The Bx residual blocks
    def residual_block(inputs, output_channel, stride, scope):
        with tf.variable_scope(scope):
            #3x3kernel
            net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
            net = batchnorm(net, FLAGS.is_training)
            net = prelu_tf(net)
            #3x3kernel
            net = conv2(net, 3, output_channel, stride, use_bias=False, scope='conv_2')
            net = batchnorm(net, FLAGS.is_training)
            #skip connection
            net = net + inputs

        return net


    with tf.variable_scope('generator_unit', reuse=reuse):
        # The input layer
        with tf.variable_scope('input_stage'):
            # kernel size 9x9 kernel count(feature map) 64 stride 1
            net = conv2(gen_inputs, 9, 64, 1, scope='conv')
            net = prelu_tf(net)

        stage1_output = net

        # The residual block parts
        for i in range(1, FLAGS.num_res
  • 4
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值