超分辨率重构之SRCNN整理总结(七)

         到此为止关于超分重建的理论部分八成已经作结,关于这个tensorflow版本的SRCNN的代码解读不知道究竟需要写到什么程度才可以完美收官。大家也都明白,这个东西若写太细,略显冗杂;若写太粗,略显不够明析。反正吧,尽可能的写清楚写明细。下面是我的GitHub代码仓库:https://github.com/XiaoYunChaos,关于这篇的代码随后完整作结后我会上传至仓库,供大家讨论学习,欢迎star哦!

SRCNN(tensorflow)详解分析

  • 【1】首先,介绍一下项目结构:

              main.py 定义训练和测试参数,此后由设定的参数进行训练或测试。

    model.py是模型文件以类的方式实现

    utils.py是用来封装项目中的函数作为函数池

    psnr.py是用来做评价函数的,功能就是进行计算评价指标

              checkpoint文件夹是用来保训练模型,即chekpoint的路径

              sample文件夹是样本路径

              Train文件夹是训练集路径

              Test文件夹是测试集路径,包含Set5与Set14

        在看懂代码前,一定要明白一件事就是我们每一次训练实际上是训练图片的大小和输出图片等的大小等参数的设置。项目除了一般的预处理操作,还需要将图片分割,最后的训练完还做实验的时候还需要将图片结合起来。

  • 【2】main.py

        功能:定义训练和测试参数,包括:batchSize、学习率、步长stride、训练、测试等。

函数运行开启:

if __name__ == '__main__':
    # main()
    tf.app.run()

随后tf.app运行,此时涉及相关参数:

flags = tf.app.flags
#第一个是参数名称,第二个参数是默认值,第三个是参数描述
flags.DEFINE_integer("epoch", 15000, "训练多少波Number of epoch [15000]")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
flags.DEFINE_integer("batch_size", 128, "batch size")
#一开始将batch size设为128和64,不仅参数初始loss很大,而且往往一段时间后训练就发散
#batch中每个样本产生梯度竞争可能比较激烈,所以导致了收敛过慢
#后来就改回了128
flags.DEFINE_integer("image_size", 33, "图像使用的尺寸 The size of image to use [33]")
flags.DEFINE_integer("label_size", 21, "label_制作的尺寸 The size of label to produce [21]")
#学习率文中设置为 前两层1e-4 第三层1e-5
#SGD+指数学习率10-2作为初始
flags.DEFINE_float("learning_rate", 1e-4, "学习率 The learning rate of gradient descent algorithm [1e-4]")
flags.DEFINE_integer("c_dim", 1, "图像维度 Dimension of image color. [1]")
flags.DEFINE_integer("scale", 3, "sample的scale大小 The size of scale factor for preprocessing input image [3]")
#stride训练采用14,测试采用21
flags.DEFINE_integer("stride", 14, "步长为14或者21 The size of stride to apply input image [14]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "名字 Name of checkpoint directory [checkpoint]")
flags.DEFINE_string("sample_dir", "sample", "名字 Name of sample directory [sample]")
flags.DEFINE_boolean("is_train", True, "True for training, False for testing [True]")#训练
#flags.DEFINE_boolean("is_train", False, "True for training, False for testing")#测试
FLAGS = flags.FLAGS
#第一句是赋值,将前面的一系列参数赋值给FLAGS。
#第二句是创建了一个打印的类,这样就可以调用pp的函数了。
pp = pprint.PrettyPrinter()

此时需要注意这些参数:

  • epoch:迭代次数
  • batch_size:批处理参数
  • image_size:图像大小
  • label_size:高分辨率图像大小,即真实标签的大小
  • learning_rate:学习率
  • c_dim:图像颜色维度
  • scale:缩放倍数
  • stride:卷积步长
  • checkpoint_dir:模型保存路径
  • sample_dir:样本路径
  • is_train:是否训练
     
  • 【3】main函数

CPU版本:

def main(_): #CPU版本
  pp.pprint(flags.FLAGS.__flags)
  #路径检查,没有则创建
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  #tf的相关参数传入及srcnn模型训练或测试
  with tf.Session() as sess:  
    #new出一个类对象,这个对象你可以理解为这个三层神经网络
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)
    #训练模型
    srcnn.train(FLAGS)

 

GPU版本:

def main(_): #GPU版本:
  pp.pprint(flags.FLAGS.__flags)
  #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  #主函数验证路径是否存在,如果不存在就创造一个
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config) as sess:
    
    #sess = tf.Session()
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim,
                  #图像维度 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)

    srcnn.train(FLAGS)
    print(srcnn.train(FLAGS))

        GPU版本与CPU版本代码理解无多大区别,就是在项目部署上可能不一样,GPU的存在有什么好处呢,说白了就是模型训练加速器,可以更快更高效的将模型训练出来,对于GPU的相关笔记随后再做解释吧,你只要把CPU代码理解了,其他的都是锦上添花。

        上述main函数可以说是已经将项目框架跑完了,随后就是一些细节上的理解和处理了。

  • 【4】model.py

from utils import (
  read_data, 
  input_setup, 
  imsave,
  merge
)

import time
import os
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

try:
  xrange
except:
  xrange = range

class SRCNN(object):

  def __init__(self, 
               sess, 
               image_size=33,
               label_size=21, 
               batch_size=128,
               c_dim=1, 
               checkpoint_dir=None, 
               sample_dir=None):

    self.sess = sess
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size

    self.c_dim = c_dim

    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir
    self.build_model()
#搭建网络
  def build_model(self):   #三层网络结构
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
    #第一层CNN:对输入图片的特征提取。(9 x 9 x 64卷积核)
    #第二层CNN:对第一层提取的特征的非线性映射(1 x 1 x 32卷积核)
    #第三层CNN:对映射后的特征进行重建,生成高分辨率图像(5 x 5 x 1卷积核)
    #权重    
    self.weights = {
      #论文中为提高训练速度的设置 n1=32 n2=16
      'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
      'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
    }
    self.biases = {
      'b1': tf.Variable(tf.zeros([64]), name='b1'),
      'b2': tf.Variable(tf.zeros([32]), name='b2'),
      'b3': tf.Variable(tf.zeros([1]), name='b3')
    }

    self.pred = self.model()
    # Loss function (MSE)以MSE为损失函数
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
    #主函数调用(训练或测试)
    self.saver = tf.train.Saver()
#训练
  def train(self, config):
    if config.is_train:#判断是否为训练(main传入)
      input_setup(self.sess, config)
    else:
      nx, ny = input_setup(self.sess, config)
	#训练为checkpoint下train.h5
    #测试为checkpoint下test.h5
    if config.is_train:     
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
    else:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")
	#训练数据标签
    train_data, train_label = read_data(data_dir)
	#读取.h5文件(由测试和训练决定)
    # Stochastic gradient descent with the standard backpropagation
    self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)

    tf.global_variables_initializer().run()
    
    counter = 0
    start_time = time.time()

    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")
	#训练
    if config.is_train:
      print("Training...")

      for ep in xrange(config.epoch):#迭代次数的循环
      	#以batch为单元
        # Run by batch images
        batch_idxs = len(train_data) // config.batch_size
        for idx in xrange(0, batch_idxs):
          batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
          batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]

          counter += 1
          _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})

          if counter % 10 == 0:#10的倍数的step显示
            print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
              % ((ep+1), counter, time.time()-start_time, err))

          if counter % 500 == 0:#500的倍数step储存
            self.save(config.checkpoint_dir, counter)
	#测试
    else:
      print("Testing...")

      result = self.pred.eval({self.images: train_data, self.labels: train_label})

      result = merge(result, [nx, ny])
      result = result.squeeze()#除去size为1的维度
      #result= exposure.adjust_gamma(result, 1.07)#调暗一些
      image_path = os.path.join(os.getcwd(), config.sample_dir)
      image_path = os.path.join(image_path, "test_image.png")
      imsave(result, image_path)

  def model(self):
  #strides在官方定义中是一个一维具有四个元素的张量,其规定前后必须为1,所以我们可以改的是中间两个数,中间两个数分别代表了水平滑动和垂直滑动步长值。
    conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
    conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
    return conv3

  def save(self, checkpoint_dir, step):
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)#再一次确定路径为 checkpoint->srcnn_21下

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name), #文件名为SRCNN.model-迭代次数
                    global_step=step)

  def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
#路径为checkpoint->srcnn_labelsize(21)
#加载路径下的模型(.meta文件保存当前图的结构; 
#.index文件保存当前参数名; .data文件保存当前参数值)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        #saver.restore()函数给出model.-n路径后会自动寻找参数名-值文件进行加载
    
        return True
    else:
        return False

训练方式:SGD的效果更好

  • 【5】utils.py
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""

import os
import glob#导入glob库,作用是类似于系统的文件路径匹配查询
import h5py#h5py库,主要用于读取或创建datasets或groups
import random
import matplotlib.pyplot as plt

from PIL import Image  # for loading images as YCbCr format
import scipy.misc#该库主要用于将数组保存成图像形式
import scipy.ndimage#该库用于图像处理
import numpy as np

import tensorflow as tf

try:
  xrange#处理异常中断
except:
  xrange = range
  
FLAGS = tf.app.flags.FLAGS#命令行参数传递

def read_data(path):#读取.h5文件的data和label数据,转化np.array格式
  """
  Read h5 format data file
  读取h5格式数据文件,用于训练或者测试
  参数:
    路径: 文件
    data.h5 包含训练输入
    label.h5 包含训练输出
  Args:
    path: file path of desired file
    data: '.h5' file format that contains train data values
    label: '.h5' file format that contains train label values
  """
  with h5py.File(path, 'r') as hf:#读取h5格式数据文件(用于训练或测试)
    data = np.array(hf.get('data'))
    label = np.array(hf.get('label'))
    return data, label

def preprocess(path, scale=3):#定义预处理函数
#(1)读取灰度图像;
#(2)modcrop;
#(3)归一化;
#(4)两次bicubic interpolation

返回input_ ,label_

make_data(sess,data,label)**
作用:将data(checkpoint下的train.h5 或test.h5)利用h5的create_dataset 写入
  """
  #对路径下的image裁剪成scale整数倍,再对image缩小1/scale倍后,放大scale倍以得到低分辨率图input_,调整尺寸后的image为高分辨率图label_
  #image = imread(path, is_grayscale=True)
  #label_ = modcrop(image, scale)
  Preprocess single image file 
    (1) Read original image as YCbCr format (and grayscale as default)
    (2) Normalize
    (3) Apply image file with bicubic interpolation

  Args:
    path: file path of desired file
    input_: image applied bicubic interpolation (low-resolution)
    label_: image with original resolution (high-resolution)
  """
  image = imread(path, is_grayscale=True)
  label_ = modcrop(image, scale)

  # Must be normalized
  image = image / 255.
  label_ = label_ / 255.

  input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False)
  input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False)

  return input_, label_

def prepare_data(sess, dataset):#作用:返回data是训练集或测试集bmp格式的图像
#(1)参数说明:dataset是train dataset 或 test dataset
#(2)glob.glob得到所有的训练集或是测试集图像
  """
  Args:
    dataset: choose train dataset or test dataset
    
    For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
  """
  if FLAGS.is_train:
    filenames = os.listdir(dataset)
    data_dir = os.path.join(os.getcwd(), dataset)
    data = glob.glob(os.path.join(data_dir, "*.bmp"))
    #(2)glob.glob得到所有的训练集或是测试集图像
  else:
  #确定测试数据集合的文件夹为Set5
    data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
    data = glob.glob(os.path.join(data_dir, "*.bmp"))

  return data

def make_data(sess, data, label):
  """
  Make input data as h5 file format
  Depending on 'is_train' (flag value), savepath would be changed.
  """
  #把数据保存成.h5格式
  if FLAGS.is_train:
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
  else:
    savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')

  with h5py.File(savepath, 'w') as hf:
    hf.create_dataset('data', data=data)
    hf.create_dataset('label', data=label)

def imread(path, is_grayscale=True):#目的:读取指定路径的图像
  """
  Read image using its path.
  Default value is gray-scale, and image is read by YCbCr format as the paper said.
  """
  #读指定路径的图像
  if is_grayscale:
    return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
  else:
    return scipy.misc.imread(path, mode='YCbCr').astype(np.float)

def modcrop(image, scale=3):
#把图像的长和宽都变成scale的倍数
  """
  To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
  
  We need to find modulo of height (and width) and scale factor.
  Then, subtract the modulo from height (and width) of original image size.
  There would be no remainder even after scaling operation.
  """
  if len(image.shape) == 3:
    h, w, _ = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w, :]
  else:
    h, w = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w]
  return image
  #把result变为和origin一样的大小

def input_setup(sess, config):#功能:读取train set or test set ;做sub-images;保存成h5文件
  """
  Read image files and make their sub-images and saved them as a h5 file format.
  """
  #global nx#后加
  #global ny#后加
  #读图像集,制作子图并保存为h5文件格式
  # 读取数据路径
  # Load data path
  if config.is_train:
    data = prepare_data(sess, dataset="Train")
  else:
    data = prepare_data(sess, dataset="Test")

  sub_input_sequence = []
  sub_label_sequence = []
  padding = abs(config.image_size - config.label_size) / 2 # 6
#padding=0;#修改padding值,测试效果
  #训练
  if config.is_train:
    for i in xrange(len(data)):#一幅图作为一个data
      input_, label_ = preprocess(data[i], config.scale)
#得到data[]的LR和HR图input_和label_
      if len(input_.shape) == 3:
      if len(input_.shape) == 3:
        h, w, _ = input_.shape
      else:
        h, w = input_.shape
#把input_和label_分割成若干自图sub_input和sub_label
      for x in range(0, h-config.image_size+1, config.stride):
        for y in range(0, w-config.image_size+1, config.stride):
          sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
          sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]

          # Make channel value
          sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
          #按image size大小重排 因此 imgae_size应为33 而label_size应为21
          sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

          sub_input_sequence.append(sub_input)
          #在sub_input_sequence末尾加sub_input中元素 但考虑为空
          sub_label_sequence.append(sub_label)
          sub_label_sequence.append(sub_label)

  else:
  #测试
    input_, label_ = preprocess(data[2], config.scale)#测试图片
    if len(input_.shape) == 3:
      h, w, _ = input_.shape
    else:
      h, w = input_.shape

    # Numbers of sub-images in height and width of image are needed to compute merge operation.
    nx = ny = 0 
    #自图需要进行合并操作
    for x in range(0, h-config.image_size+1, config.stride):#x从0到h-33+1 步长stride(21)
      nx += 1; ny = 0
      for y in range(0, w-config.image_size+1, config.stride):#y从0到w-33+1 步长stride(21)
        ny += 1
        sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
        sub_label = label_[x+int(padding):x+int(padding)+config.label_size, y+int(padding):y+int(padding)+config.label_size] # [21 x 21]
        
        sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
        sub_label = sub_label.reshape([config.label_size, config.label_size, 1])

        sub_input_sequence.append(sub_input)
        sub_label_sequence.append(sub_label)

  """
  len(sub_input_sequence) : the number of sub_input (33 x 33 x ch) in one image
  (sub_input_sequence[0]).shape : (33, 33, 1)
  """
  # Make list to numpy array. With this transform
  # 上面的部分和训练是一样的
  arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
  arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]

  make_data(sess, arrdata, arrlabel)

  if not config.is_train:#存成h5格式
    return nx, ny
    
def imsave(image, path):
  return scipy.misc.imsave(path, image)

def merge(images, size):
  h, w = images.shape[1], images.shape[2]#觉得下标应该是0,1
  img = np.zeros((h*size[0], w*size[1], 1))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w, :] = image

  return img

        utils.py说明了就是一个函数池,注意下面函数就可以:

  • prepare_data(sess,dataset):返回data,data是训练集或测试集中bmp格式的图像。
  • input_setup(sess,config):读取train set or test set ;做sub-images;保存成h5文件。
  • read_data(path):读取.h5文件的data和label数据,转化np.array格式。
  • preprocess(path,scale=3):(1)读取灰度图像;(2)modcrop;(3)归一化;(4)两次bicubic interpolation,返回input_ ,label_。即对路径下的image裁剪成scale整数倍,再对image缩小1/scale倍后,放大scale倍以得到低分辨率图input_,调整尺寸后的image为高分辨率图label_。
  • make_data(sess,data,label):将data保存为h5格式的数据,保存到指定路径,是通过create_dataset函数写入的。
  • imread(path,is_grayscale=True):读取指定路径的图像。
  • modcrop(image, scale=3) #把图像的长和宽都变成scale的倍数。
  • modcrop_small(image) #把result变为和origin一样的大小(需要自己写或参考其他)。
  • imsave(image,path):将scipy.misc.imsave封装到imsave供自己使用。
  • merge(image,size):合并分割后的图片。

到这里差不多,代码解读基本完成。相信你看完之后也可以自己完成运行测试啦!

  • 【6】最后,再附一个项目运行基本流程:
  • 准备数据集(训练集、测试集);
  • 训练模型
  • 利用模型测试数据
  • 模型评价

 

评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值