这篇开始,讲述自己对于tensorflow文档中,利用CNN建立CIFAR-10模型的理解,如有错误欢迎指正,也是互相学习。由于代码太长,所以分几篇来讲述。
第一篇是关于cifar10_input.py文件。
cifar10_input.py
这个文件主要用来进行数据读取以及输入数据处理。
Tensorflow一共有三种读取数据的方式:
第一种最简单的预加载数据,直接在graph中定义常量和变量来保存数据(仅适用于数据小的情况,神经网络自然这种方法不行)
第二种供给数据(feeding),那么这种方法前面出现过很多次,也就是在图中建立placeholder,然后在跑图的时候再对占位进行数据feed。(如果数据量过大,一次性读入所有数据,再分批次feed进图,也会占用太多的内存空间)
第三种从文件中读取数据,在Tensorflow的起始,用一个输入管线从文件中读取数据。
那么这个模型就是运用了第三种数据读取方式。
这里我们按照实际建立图的顺序来进行讲解,由cifar10_train.py文件我们知道先调用的是distorted_input函数。
distorted_input
函数输入为data_dir(训练集所在文件夹),以及batch_size。
函数输出为图像组成的4维tensor,以及labels组成的1维tensor。
下面为数据读取步骤:
- 生成文件名列表,也就是把需要输入的所有文件的文件名放到一个列表里。
- 将文件名列表输入到tf.train.string_input_producer()函数中,生成一个先入先出的文件名队列,同时将一个QueueRunner添加到整个图的QUEUE_RUNNER当中。(tf.train.QueueRunner本质上是tensorflow的一个类,用来完成队列的一系列入队操作)
下面为API文档,具体请参考文档。
https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer - 利用文件阅读器读取文件名队列中文件里的数据。接下来我们跟随代码转移到read_cifar10函数中。
- 从read_cifar10函数回来,这个时候read_input已经是刚刚输出的结构体了,它有一个样本的所有信息。然后对图像数据进行一系列预处理,这是data augmentation,包括随机裁取,随机左右翻转,随机亮度调节,随机对比度调节,图像归一化。
所有关于图像操作的API文档
https://www.tensorflow.org/api_guides/python/image - 定义参数min_queue_examples,然后进入到generate_image_and_label_batch()函数中。
def distorted_inputs(data_dir, batch_size):
# 生成文件名列表
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, 6)]
for f in filenames:
if not gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# 生成文件名队列
filename_queue