最近在学习SRCNN,阅读代码做好笔记
代码下载链接https://github.com/tegg89/SRCNN-Tensorflow
下面开始
from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os
flags = tf.app.flags
flags.DEFINE_integer("epoch", 2000,"训练多少波")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
#一开始将batch size设为128和64,不仅参数初始loss很大,而且往往一段时间后训练就发散
#batch中每个样本产生梯度竞争可能比较激烈,所以导致了收敛过慢
#后来改回了128
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("image_size", 33, "图像使用的尺寸")
flags.DEFINE_integer("label_size", 21, "label_制作的尺寸")
#学习率文中设置为 前两层1e-4 第三层1e-5
#SGD+指数学习率10-2作为初始
flags.DEFINE_float("learning_rate", 1e-2, "学习率")
flags.DEFINE_integer("c_dim", 1, "图像维度")
flags.DEFINE_integer("scale", 3, "sample的scale大小")
#stride训练采用14,测试采用21
flags.DEFINE_integer("stride", 21 , "步长为14或者21")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "checkpoint directory名字")
flags.DEFINE_string("sample_dir", "sample", "sample directory名字")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#测试
#flags.DEFINE_boolean("is_train", True, "True for training, False for testing")#训练
FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()
def main(_):
pp.pprint(flags.FLAGS.__flags)
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
with tf.Session() as sess:
srcnn = SRCNN(sess,
image_size=FLAGS.image_size,
label_size=FLAGS.label_size,
batch_size=FLAGS.batch_size,
c_dim=FLAGS.c_dim,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
srcnn.train(FLAGS)
if __name__ == '__main__':
tf.app.run()