这次开始读一个基于tensorflow的resnet的完整实现的源码,整个项目一共包含了三个文件cifar_input.py、resnet_main.py、resnet_model.py。这次先针对cifar_input.py文件进行源码的解析,该文件主要是实现数据集的预处理和读入的功能,具体代码先贴在下面。
关于cifar数据集的介绍以及如何将二进制数据输出为jpg格式可以参考博客【1】
在代码中使用tf.RandomShuffleQueue具体的介绍可以参考博客
【2】
在代码中使用tf.FIFOQueue具体的介绍可以参考博客【3】
对于tensorflow中批量输入数据的实现可以参考博客【4】
import tensorflow as tf
def build_input(dataset, data_path, batch_size, mode):
"""Build CIFAR image and labels.
Args:
dataset(数据集): Either 'cifar10' or 'cifar100'.
data_path(数据集路径): Filename for data.
batch_size: Input batch size.
mode(模式): Either 'train' or 'eval'.
Returns:
images(图片): Batches of images. [batch_size, image_size, image_size, 3]
labels(类别标签): Batches of labels. [batch_size, num_classes]
Raises:
ValueError: when the specified dataset is not supported.
"""
# 数据集参数
# 由于数据是二进制文件读入的,数据集中标签和图片数据在读入后是存放在一个一维
# 列表中的,所以label_offset表示了二进制文件中表示标签字节起始位置的偏移量
# label_bytes代表了标签在二进制文件中所占的字节数。
image_size = 32
if dataset == 'cifar10':
label_bytes = 1
label_offset = 0
num_classes = 10
elif dataset == 'cifar100':
label_bytes = 1
label_offset = 1
num_classes = 100
else:
raise ValueError('Not supported dataset %s', dataset)
# 数据读取参数
depth = 3
# image_bytes代表了一张图片所占的字节数量
image_bytes = image_size * image_size * depth
# record_bytes代表了一个二进制文件内标签加上图片本身数据总体的字节数量
record_bytes = label_bytes + label_offset + image_bytes
# 获取文件名列表
# data_path是一个路径下的文件名或者符合正则pattern的文件,data_file就是在
# data_path中所匹配的所有文件所构成的列表
data_files = tf.gfile.Glob(data_path)
# 文件名列表生成器
file_queue = tf.train.string_input_producer(data_files, shuffle=True)
# 文件名列表里读取原始二进制数据
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
_, value = reader.read(file_queue)
# 将原始二进制数据转换成图片数据及类别标签
record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes])
label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32)
# 将数据串 [depth * height * width] 转换成矩阵 [depth, height, width].
depth_major = tf.reshape(tf.slice(record, [label_bytes], [image_bytes]),
[depth, image_size, image_size])
# 转换维数:[depth, height, width]转成[height, width, depth].
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
if mode == 'train':
# 增减图片尺寸
image = tf.image.resize_image_with_crop_or_pad(
image, image_size+4, image_size+4)
# 随机裁剪图片
image = tf.random_crop(image, [image_size, image_size, 3])
# 随机水平翻转图片
image = tf.image.random_flip_left_right(image)
# 逐图片做像素值中心化(减均值)
image = tf.image.per_image_standardization(image)
# 建立输入数据队列(随机洗牌)
example_queue = tf.RandomShuffleQueue(
# 队列容量
capacity=16 * batch_size,
# 队列数据的最小容许量
min_after_dequeue=8 * batch_size,
dtypes=[tf.float32, tf.int32],
# 图片数据尺寸,标签尺寸
shapes=[[image_size, image_size, depth], [1]])
# 读线程的数量
num_threads = 16
else:
# 获取测试图片,并做像素值中心化
image = tf.image.resize_image_with_crop_or_pad(
image, image_size, image_size)
image = tf.image.per_image_standardization(image)
# 建立输入数据队列(先入先出队列)
example_queue = tf.FIFOQueue(
3 * batch_size,
dtypes=[tf.float32, tf.int32],
shapes=[[image_size, image_size, depth], [1]])
# 读线程的数量
num_threads = 1
# 数据入队操作
example_enqueue_op = example_queue.enqueue([image, label])
# 队列执行器
tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner(
example_queue, [example_enqueue_op] * num_threads))
# 数据出队操作,从队列读取Batch数据
images, labels = example_queue.dequeue_many(batch_size)
# 将标签数据由稀疏格式转换成稠密格式
# [ 2, [[0,1,0,0,0]
# 4, [0,0,0,1,0]
# 3, --> [0,0,1,0,0]
# 5, [0,0,0,0,1]
# 1 ] [1,0,0,0,0]]
labels = tf.reshape(labels, [batch_size, 1])
indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
labels = tf.sparse_to_dense(
tf.concat(values=[indices, labels], axis=1),
[batch_size, num_classes], 1.0, 0.0)
#检测数据维度
assert len(images.get_shape()) == 4
assert images.get_shape()[0] == batch_size
assert images.get_shape()[-1] == 3
assert len(labels.get_shape()) == 2
assert labels.get_shape()[0] == batch_size
assert labels.get_shape()[1] == num_classes
# 添加图片总结
tf.summary.image('images', images)
return images, labels
(1)二进制文件的读入操作
从源码中可以发现在读入二进制文件的原始信息时调用了tf.gfile.Glob、tf.train.string_input_producer、tf.FixedLengthRecordReader、reader.read等API。下面就针对这几个函数进行详细的介绍:
tf.gfile.Glob()
函数功能:查找匹配pattern的文件并以列表的形式返回,filename可以是一个具体的文件名,也可以是包含通配符的正则表达式。因此该函数返回的是所有满足匹配规则的文件名所构成的列表。
样例: 在样例代码中我首先设计了一个匹配规则data_path,之后调用tf.gfile.Glob(data_path)来获取文件列表。
data_path = r"D:\SourceInsightData\ResNet_TF\test*.txt"
data_files = tf.gfile.Glob(data_path)
print(data_files)
最终在控制台的输出结果为:
['D:\\SourceInsightData\\ResNet_TF\\test.txt', 'D:\\SourceInsightData\\ResNet_TF\\test1.txt']
tf.train.string_input_producer()
函数功能: 输出字符串到一个输入管道队列。
主要参数定义:
1、string_tensor: 1-D字符串Tensor
2、num_epochs=None: 一个整数(可选)。如果指定,string_input_producer在产生OutOfRange错误之前从string_tensor中产生num_epochs次字符串。如果未指定,则可以无限次循环遍历字符串。需要注意的是:当指定该参数为某一整数时,需要在调用tf.run之前初始化局部变量sess.run(tf.local_variables_initializer()),否则会报错。
3、shuffle=True: 当设置为True时文件队列的存放顺序会被随机打乱。
4、seed=None: 一个整数(可选)。如果shuffle==True,则使用种子。
5、capacity=32: 一个整数。设置队列容量。
tf.FixedLengthRecordReader()
函数功能: 创建一个可从文件名列表中读取内容的对象
在源码中可以发现该函数包含一个参数record_bytes,这就表示reader对象在每次调用read函数时都只能从相应的文件列表内读取固定大小的字节数量。
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
_, value = reader.read(file_queue)
样例: 在样例中每次从文件名队列内的文件读取7个字节的数据
data_path = r"D:\SourceInsightData\ResNet_TF\test*.txt"
data_files = tf.gfile.Glob(data_path)
file_queue = tf.train.string_input_producer(data_files, num_epochs=3, shuffle=True)
reader = tf.FixedLengthRecordReader(record_bytes=7)
key, value = reader.read(file_queue)
sess = tf.InteractiveSession()
init = tf.local_variables_initializer()
tf.train.start_queue_runners(sess=sess)
sess.run(init)
print(sess.run([key, value]))
print(sess.run([key, value]))
最终在控制台的输出结果为:
[b'D:\\SourceInsightData\\ResNet_TF\\test.txt:0', b'1234567']
[b'D:\\SourceInsightData\\ResNet_TF\\test1.txt:0', b'abcdefg']
(2)数据增强和预处理
在源码中调用了tf.image.resize_image_with_crop_or_pad、tf.random_crop、tf.image.random_flip_left_right三个API对原始数据进行了数据增强,调用了tf.image.per_image_standardization对图像数据进行了去中心话的操作。
tf.image.resize_image_with_crop_or_pad()
函数功能: 对图片进行裁剪和填充,当原图片尺寸大于目标图片
时,自动截取原图片居中位置。
从源码中可以知道函数包含三个参数,Image表示输入的图像数据,后两个参数分别表示所需要裁剪的最终图像的大小。
image = tf.image.resize_image_with_crop_or_pad(image,
image_size+4, image_size+4)
tf.random_crop()
函数功能: 随机裁剪图片。
从源码中可以发现该函数包含两个参数,image表示输入的图像数据,后一个参数表示随机裁剪的图片的shape。
image = tf.random_crop(image, [image_size, image_size, 3])
tf.image.random_flip_left_right()
函数功能: 随机水平翻转图片。
tf.image.per_image_standardization()
函数功能: 逐图片做像素值中心化。