学习笔记之——基于tensorflow的VESPCN

之前已经写了挺多的博文关于image 超分了,接下来研究一下video 超分。之前博文已经对VESPCN进行了理论的介绍(《学习笔记之——基于深度学习的图像超分辨率重构》)

之前做的超分都是基于pytorch的,这次这个项目是tensorflow的。。。。然后一开始学习超分的代码是cafe的。。。。。。?本博文先对基于tensorflow的VESPCN做深入的剖析,然后再移植到之前的代码库中,用pytorch实现。

 

 

本博文主要参考代码:https://github.com/zhxl0903/CSCD94_VESPCN

数据集放于文件train中,测试集放于test文档中。

注释并作一定修改的代码如下(部分注释与修改来自课题组的师弟哈~):

main.py

import tensorflow as tf
from model import ESPCN
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '5'

"""
    执行main函数之前首先进行flags的解析,
    也就是说TensorFlow通过设置flags来传递the required parameters to the tf.app.run()in the main.py
    也就是超参数的人为设置,然后存放在 FLAGS 中
    参数如下:
"""
flags = tf.app.flags
FLAGS = flags.FLAGS

"""
       参数名称           默认值             描述
    1.  epoch             1200             1个epoch表示过了1遍训练集中的所有样本,这里表示一共训练1200次
    2.  image_size        32               读入图片的尺寸(不是原图尺寸,而是网络入口要求的图片尺寸(数据子图),见详解
    3.  c_dim             3                图片通道数,一般RGB因此为3
    4.  train_mode        0                作者在此项目中实现了不同的网络训练模型,此参数是各模型代号(0~6)
    5.  scale             3                用于预处理输入图像的尺度因子的大小,即缩小比例
    6.  stride            100              步长,用于裁剪原图生成数据子图,见详解
    7.  checkpoint_dir    checkpoint       文件名,用于(暂时)存放训练所得模型参数
    8.  learning_rate     0.0001           学习速度
    9.  batch_size        128              1次迭代所使用的样本量,即一次抓取多少图片(不是原图,是裁剪过后的数据子图)
    10. result_dir        result           文件名,用于存放结果【不很重要】
    11. test_img          ‘ ’              【不清楚】
    12. load_existing_data False           判断hf数据是否被加载【暂不明白】
"""
#         第一个是参数名称,   第二个参数是默认值,   第三个是参数描述
flags.DEFINE_integer("epoch",          1200,           "Number of epoch")
flags.DEFINE_integer("image_size",     32,             "The size of image input")
flags.DEFINE_integer("c_dim",          3,              "The size of channel")
flags.DEFINE_boolean("is_train",       True,          "if training")
flags.DEFINE_integer("train_mode",     2,
                     "0: Spatial Transformer       1: 9-L Single-Frame ESPCN    2: 9L-E3-MC VESPCN \
                      3: Bicubic(No Training Required)          4: SRCNN \
                      5: Multi-Dir mode for testing mode 2       6: Multi-Dir mode for testing mode 1 \
                      7: FSRCNN ")
flags.DEFINE_integer("scale",          3,              "the size of scale factor for preprocessing input image")
flags.DEFINE_integer("stride",         100,            "the size of stride")
flags.DEFINE_string("checkpoint_dir",  "checkpoint",   "Name of checkpoint directory")
flags.DEFINE_float("learning_rate",    1e-5 ,          "The learning rate")
flags.DEFINE_integer("batch_size",     128,            "the size of batch")
flags.DEFINE_string("result_dir",      "result",       "Name of result directory")
flags.DEFINE_string("test_img",        "",             "test_img")
flags.DEFINE_boolean("load_existing_data",    False,
                     "True iff existing hf data is loaded for training/testing")


def main(_):

    """
    设置配置并启用GPU内存分配增长,tf.ConfigProto()用以配置Session运行参数
    当其参数allow_growth设置为True时,分配器将不会指定所有的GPU内存,而是动态申请显存
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    
    # 检查训练模式是否为3、5或6,is_train是否开启
    # 因为mode3\mode5\mode6是不用训练的,若在这些模式下要求训练,则直接退出
    if FLAGS.train_mode == 3 and FLAGS.is_train:
        print('Error: Bicubic Mode does not require training')
        return
    elif FLAGS.train_mode == 5 and FLAGS.is_train:
        print('Error: Multi-Dir testing mode for Mode 2 does not require training')
        return
    elif FLAGS.train_mode == 6 and FLAGS.is_train:
        print('Error: Multi-Dir testing mode for Mode 1 does not require training')
        return

    """ 
        接下来执行会话,初始化ESPCN对象,开始训练或测试" \
        ESPCN对象:
            1. 基于OOP思想,本项目绝大多数与神经网络有关的超参数、操作都被封装于一个ESPCN类的对象中
            2. ESPCN对象被创建后,会自动依据传入的超参数,进行对应mode的网络的构建
            3. 随后执行train()函数,进行训练/测试 + 保存结果等一系列操作(函数名虽然是train,也可以用于测试)
    """
   # 给Session配置运行参数
    with tf.Session(config=config) as sess:
        # 使用传递进来的人工设定的参数初始化ESPCN对象
        espcn = ESPCN(sess,  #会话信息
                      image_size=FLAGS.image_size,  # 图片尺寸
                      is_train=FLAGS.is_train,      # 是否训练模式
                      train_mode=FLAGS.train_mode,  # 何种训练模式
                      scale=FLAGS.scale,            # 尺度
                      c_dim=FLAGS.c_dim,            # 图片通道数
                      batch_size=FLAGS.batch_size,  # 单次抓取数据量
                      load_existing_data=FLAGS.load_existing_data, # 数据是否读入
                      config=config                 # 会话参数
                      )
        # 开始训练,将所需各参数传递进去

        espcn.train(FLAGS)


if __name__ == '__main__':
    
    # 解析命令行参数;然后调用主函数
    tf.app.run() 

model.py

 

 

 

 

 

 

 

训练时要保证

flags.DEFINE_boolean("is_train",       True,          "if training")

训练的模式

("train_mode", 0, "0: Spatial Transformer 1: 9-L Single-Frame ESPCN\
                     2: 9L-E3-MC VESPCN 3: Bicubic (No Training Required) 4: SRCNN \
                     5: Multi-Dir mode for testing mode 2 6: Multi-Dir mode \
                     for testing mode 1")

 

 

 

补充

关于tf.app.run() 

if __name__ == '__main__':
    tf.app.run() 

处理flag解析,然后执行main函数,而flag解析就是输入的参数

https://blog.csdn.net/helei001/article/details/51859423

关于VESPCN参考https://blog.csdn.net/Cyiano/article/details/78368263?locationNum=4&fps=1

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值