最近在学习SRCNN,阅读代码做好笔记
代码下载链接https://github.com/tegg89/SRCNN-Tensorflow
下面开始
"""
Scipy version > 0.18 is needed, due to 'mode' option from scipy.misc.imread function
"""
import os #导入os库,主要用于系统命令处理
import glob #导入glob库,作用是类似于系统的文件路径匹配查询
import h5py #h5py库,主要用于读取或创建datasets或groups
import random #随机数库,主要用于生成随机数
import matplotlib.pyplot as plt #导入matlpotlib.pyplot,画图,数据可视化
from PIL import Image # for loading images as YCbCr format
import scipy.misc #该库主要用于将数组保存成图像形式
import scipy.ndimage #该库用于图像处理
import numpy as np
import tensorflow as tf
try:
xrange
except:
xrange = range #处理异常中断
FLAGS = tf.app.flags.FLAGS #命令行参数传递
def read_data(path):
"""
Read h5 format data file
Args:
path: file path of desired file
data: '.h5' file format that contains train data values
label: '.h5' file format that contains train label values
读取h5格式的数据文件
参数:
路径:所需文件的文件路径
数据:”。包含训练数据值的h5'文件格式
标签:”。包含训练标签值的h5'文件格式
"""
with h5py.File(path, 'r') as hf:
data = np.array(hf.get('data'))
label = np.array(hf.get('label'))
return data, label
def preprocess(path, scale=3):#定义预处理函数
"""
Preprocess single image file
(1) Read original image as YCbCr format (and grayscale as default)
(2) Normalize
(3) Apply image file with bicubic interpolation
Args:
path: file path of desired file
input_: image applied bicubic interpolation (low-resolution)
label_: image with original resolution (high-resolution)
预处理单个图像文件
(1)读取原始图像为YCbCr格式(默认灰度)
(2)标准化
(3)应用双三次插值的图像文件
参数:
路径:所需文件的文件路径
input_:图像应用双三次插值(低分辨率)
标签_:原始分辨率图像(高分辨率)
"""
image = imread(path, is_grayscale=True)
label_ = modcrop(image, scale) #对图像进行缩放操作
# Must be normalized
image = image / 255.
label_ = label_ / 255. #归一化操作
input_ = scipy.ndimage.interpolation.zoom(label_, (1./scale), prefilter=False) #先把labei即清晰的原图变模糊,没有使用预滤波
input_ = scipy.ndimage.interpolation.zoom(input_, (scale/1.), prefilter=False) #再把变模糊的图像放大,没有使用预滤波
return input_, label_
def prepare_data(sess, dataset):
"""
Args:
dataset: choose train dataset or test dataset
For train dataset, output data would be ['.../t1.bmp', '.../t2.bmp', ..., '.../t99.bmp']
"""
if FLAGS.is_train:
filenames = os.listdir(dataset)
data_dir = os.path.join(os.getcwd(), dataset) #路径拼接
data = glob.glob(os.path.join(data_dir, "*.bmp")) #路径查询匹配,返回所有匹配的文件路径列表
else:
data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
data = glob.glob(os.path.join(data_dir, "*.bmp"))
return data
def make_data(sess, data, label):
"""
Make input data as h5 file format
Depending on 'is_train' (flag value), savepath would be changed.
"""
if FLAGS.is_train:
savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
else:
savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')
with h5py.File(savepath, 'w') as hf:
hf.create_dataset('data', data=data)
hf.create_dataset('label', data=label)
def imread(path, is_grayscale=True):
"""
Read image using its path.
Default value is gray-scale, and image is read by YCbCr format as the paper said.
使用图像的路径读取图像。
默认值为灰度,图像采用YCbCr格式读取。
"""
if is_grayscale:
return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
else:
return scipy.misc.imread(path, mode='YCbCr').astype(np.float)
def modcrop(image, scale=3):
"""
To scale down and up the original image, first thing to do is to have no remainder while scaling operation.
We need to find modulo of height (and width) and scale factor.
Then, subtract the modulo from height (and width) of original image size.
There would be no remainder even after scaling operation.
要对原始图像进行缩放,首先要做的是在缩放操作时没有余数。
我们需要找到高度(和宽度)与比例因子的模。
然后,从原始图像大小的高度(和宽度)中减去模。
即使在缩放操作之后也不会有余数。
"""
if len(image.shape) == 3:
h, w, _ = image.shape
h = h - np.mod(h, scale) #mod的作用是取余数,这里返回的是h除以scale后得到的余数
w = w - np.mod(w, scale) #同上
image = image[0:h, 0:w, :]
else:
h, w = image.shape
h = h - np.mod(h, scale)
w = w - np.mod(w, scale)
image = image[0:h, 0:w]
return image
def input_setup(sess, config):
"""
Read image files and make their sub-images and saved them as a h5 file format.
读取图像文件并生成它们的子图像,并将它们保存为h5文件格式。
"""
# Load data path
if config.is_train:
data = prepare_data(sess, dataset="Train")
else:
data = prepare_data(sess, dataset="Test")
sub_input_sequence = []
sub_label_sequence = []
padding = abs(config.image_size - config.label_size) / 2 # 6
if config.is_train:
for i in range(len(data)):#一幅图作为一个data
input_, label_ = preprocess(data[i], config.scale)#得到data[]的LR和HR图input_和label_
if len(input_.shape) == 3:
h, w, _ = input_.shape
else:
h, w = input_.shape
#把input_和label_分割成若干自图sub_input和sub_label
for x in range(0, h-config.image_size+1, config.stride):
for y in range(0, w-config.image_size+1, config.stride):
sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
sub_label = label_[x+padding:x+padding+config.label_size, y+padding:y+padding+config.label_size] # [21 x 21]
sub_input = sub_input.reshape([config.image_size, config.image_size, 1])#按image size大小重排 因此 imgae_size应为33 而label_size应为21
sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
sub_input_sequence.append(sub_input)#在sub_input_sequence末尾加sub_input中元素 但考虑为空
sub_label_sequence.append(sub_label)
else:
#测试
input_, label_ = preprocess(data[0], config.scale)#测试图片
if len(input_.shape) == 3:
h, w, _ = input_.shape
else:
h, w = input_.shape
nx = 0 #后注释
ny = 0 #后注释
#自图需要进行合并操作
for x in range(0, h-config.image_size+1, config.stride): #x从0到h-33+1 步长stride(21)
nx += 1
ny = 0
for y in range(0, w-config.image_size+1, config.stride):#y从0到w-33+1 步长stride(21)
ny += 1
#分块sub_input=input_[x:x+33,y:y+33] sub_label=label_[x+6,x+6+21, y+6,y+6+21]
sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
sub_label = label_[x+padding:x+padding+config.label_size, y+padding:y+padding+config.label_size] # [21 x 21]
sub_input = sub_input.reshape([config.image_size, config.image_size, 1])
sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
sub_input_sequence.append(sub_input)
sub_label_sequence.append(sub_label)
# 上面的部分和训练是一样的
arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]
make_data(sess, arrdata, arrlabel)#存成h5格式
if not config.is_train:
return nx, ny
def imsave(image, path):
return scipy.misc.imsave(path, image)
def merge(images, size):
h, w = images.shape[1], images.shape[2] #觉得下标应该是0,1
#h, w = images.shape[0], images.shape[1]
img = np.zeros((h*size[0], w*size[1], 1))
for idx, image in enumerate(images):
i = idx % size[1]
j = idx // size[1]
img[j*h:j*h+h, i*w:i*w+w, :] = image
return img