Tensorflow 程序读取数据一共有3种方法:
- 供给数据(feeding):在程序运行的每一步,让Python代码来供给数据
- 从文件读取数据: 让一个输入管线从文件中读取数据
- 预加载数据:在tensorflow图中定义常量或变量来保存所有数据(适用于数据量小的时候)
一个典型的文件读取管线会包含下面这些步骤:
- 文件名列表
- 可配置的 文件名乱序(shuffling)
- 可配置的 最大训练迭代数(epoch limit)
- 文件名队列
- 针对输入文件格式的阅读器
- 纪录解析器
- 可配置的预处理器
- 样本队列
1.得到文件名列表
filenames=[os.path.join(data_dir,'data_batch_%d.bin'%i) for i in range(1,6)] #得到一个文件名列表
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: '+ f)
此处用list表示文件名列表,然后依次检验文件是否存在,以抛出异常
2.将文件名列表交给tf.train.string_input_producer函数,得到一个先入先出的队列(Queue),文件阅读器会需要它来读取数据
其中可配置参数中有shuffle,是bool值,判断要不要用乱序操作
filename_queue=tf.train.string_input_producer(filenames)#生成一个先入先出队列,需要用文件阅读器来读取其数据
3.得到文件名队列后,针对输入文件格式,创建阅读器进行读取
例如:若从CSV文件中读取数据,需要使用TextLineReader和decode_csv来进行读取和解码
若是CIFAR-10 dataset文件,因为每条记录的长度固定,一个字节的标签+3072像素数据
所以此处采用FixedLengthRecordReader()和decode_raw来进行读取和解码
每次read的执行都会从文件中读取一行内容, decode_csv 操作会解析这一行内容并将其转为张量列表。如果输入的参数