概述
TensorFlow读取数据,官网介绍的方法有3种:
- 预加载数据 (Preloaded data): 在Graph中定义常量或变量来保存数据。
- 供给数据 (Feeding): 在Graph运行中将Python代码产生好的数据供给TF后端。
- 从文件读取数据 (Reading from file): 在Graph的起始, 利用输入管线直接从文件中读取数据(最常用)。
看官网上这么写,还是不太清楚这几种方法到底是怎么实现的,于是查了些资料,稍作整理一番。
先理解一下TensorFlow的工作模式:
TF底层也就是计算核心模块和运行框架是用C++写的,同时提供API给Python (TF也提供了C++、Java、Go的API, 没用过, 不管),然后,Python调用这些API设计网络模型Graph,交给后端运算执行。所以Python负责Design,C++负责Run。
预加载数据
仅适用于可以完全加载到内存中的小数据集。有两种方法:
- 存储在常量中。
- 存储在变量中,且初始化后值不变。
数据存到常量中:
import tensorflow as tf
x_data = [2, 3, 4]
y_data = [4, 0, 1]
x = tf.constant(x_data)
y = tf.constant(y_data)
with tf.Session() as sess:
...
sess.run(x)
sess.run(y)
数据存到变量中,就需要在数据流图建立后初始化这个变量,而且值不能再被改变:(这里也用到了占位符,但重点是使用了变量存储数据)
import tensorflow as tf
x_data = [2, 3, 4]
y_data = [4, 0, 1]
x_initializer = tf.placeholder(dtype=x.dtype,shape=x.shape)
y_initializer = tf.placeholder(dtype=y.dtype,shape=y.shape)
x = tf.Variable(x_initializer,trainable=False,collections=[])
y = tf.Variable(y_initializer,trainable=False,collections=[])
with tf.Session() as sess:
...
sess.run(x.initializer, feed_dict={x_initializer: x_data})
sess.run(y.initializer, feed_dict={y_initializer: y_data})
设置 trainable=False 可以防止该变量被数据流图的 GraphKeys.TRAINABLE_VARIABLES 收集, 这样在训练的时候变量就不会和其他网络参数一样被更新; 设置 collections=[] 可以防止被 GraphKeys.VARIABLES 收集做为保存和恢复的中断点。
供给数据
TensorFlow有数据供给机制,允许在Graph中将数据注入到任一张量。python代码产生的数据可以通过此方式直接输入到Graph。
设计placeholder节点的唯一意图就是为了提供数据供给(feeding)的方法。placeholder节点被声明的时候是未初始化的,不包含数据,需要通过run()或者eval()函数输入feed_dict
参数, 才能启动运算。
import tensorflow as tf
x1 = tf.placeholder(tf.int16)
x2 = tf.placeholder(tf.int16)
y = tf.add(x1, x2)
#python产生数据
data1 = [2, 3, 4]
data2 = [4, 0, 1]
with tf.Session() as sess:
sess.run(y, feed_dict={x1:data1, x2:data2})
#或者用eval()
#with tf.Session():
# y.eval(feed_dict={x1:data1, x2:data2})
从文件读取数据
根据文件格式, 选择对应的文件阅读器, 然后将文件名队列提供给阅读器的read方法。read输出的key表征输入的文件和纪录,而字符串标量value可以被不同的解析器解码成张量样本。它就是我们读到的数据。
从csv文件读取数据
从CSV文件中读取数据, 需要使用 TextLineReader
和 decode_csv
,使用一个reader的写法如下:
# -*- coding:utf-8 -*-
import tensorflow as tf
#生成文件名队列
filenames = ['num1.csv', 'num2.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
#定义阅读器
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
#定义解码器,一次读一行
example, label = tf.decode_csv(value, record_defaults=[[1], [1]])
#使用tf.train.batch()相当于多加了一个样本队列和一个QueueRunner
#example, label = tf.train.batch([example,label],batch_size=3)
#example, label = tf.train.shuffle_batch([example,label],batch_size=3,capacity=100,min_after_dequeue=10)
with tf.Session() as sess:
coord = tf.train.Coordinator() #创建一个协调器,管理线程
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
example_batch, label_batch = sess.run([example, label])
print(example_batch, label_batch)
coord.request_stop()
coord.join(threads)
上面的写法是一个reader,一个样本。
如果加上tf.train.batch可以实现一个reader,batch_size个样本。
如果用tf.train.shuffle_batch的话,也可以读batch_size个样本,并且打乱顺序。
使用多个reader的写法如下:
# -*- coding:utf-8 -*-
import tensorflow as tf
filenames = ['num1.csv', 'num2.csv']
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
#定义了多个解码器,每个解码器跟一个reader相连,这里reader设置为2
example_list = [tf.decode_csv(value, record_defaults=[[1], [1]]) for _ in range(2)]
#使用tf.train.batch_join(),可以使用多个reader,并行读取数据,每个Reader使用一个线程
example, label = tf.train.batch_join(example_list, batch_size=3)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
example_batch, label_batch = sess.run([example, label])
print(example_batch, label_batch)
coord.request_stop()
coord.join(threads)
tf.train.batch与tf.train.shuffle_batch函数是单个Reader读取,可以多线程(即batch_size>1)。tf.train.batch_join与tf.train.shuffle_batch_join 可以设置多Reader读取,每个Reader使用一个线程。至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,太多的线程会使效率下降。
tf.decode_csv()中的 record_defaults = [[1], [1]]:record_defaults是解析的模板,每行有几列单元就有几个[1];整型数值解析标准是[1],浮点型是[1.0],字符型是['null']。
从图像文件读取数据
首要目标是获得图像名列表。
可以把图像文件路径存到xlsx或txt文件中,一行一个样本,再用python方法读取文件名列表。
可以用tf.gfile直接获取图像文件夹内所有文件名。(下方代码是这种)
# -*- coding:utf-8 -*-
import tensorflow as tf
import os.path
filenames = tf.gfile.ListDirectory('image_dir')
filenames = [os.path.join('image_dir', f) for f in filenames] #文件完整路径
filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
#定义阅读器
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
#定义解码器
image = tf.image.decode_jpeg(value, channels=3)
image = tf.reshape(image, [image_size, image_size, 3])
image_batch = tf.train_batch(image, batch_size=3)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
example_batch = sess.run(image_batch)
print(example_batch.shape)
coord.request_stop()
coord.join(threads)
从TFRecords文件读取数据
这种方式先要将你的数据转换为tensorflow标准格式TFRecords文件,它实际上是一种二进制文件,虽然不好理解,但能更好的利用内存,更容易与TF网络架构匹配。
TFRecords文件包含了 tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串,并且通过 tf.python_io.TFRecordWriter 类写入到TFRecords文件。
从TFRecords文件中读取数据, 可以使用 tf.TFRecordReader 的 tf.parse_single_example 解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。
import os
import tensorflow as tf
from PIL import Image
#classes是根据数据类型自定义的列表
#比如我把所有图像分类存放在class_0、class_1、class_2文件夹里
#那么classes=['class_0','class_1','class_2']
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = dataset_dir + name + "\"
#举个例子则class_path可能是“E:\image_dataset\class_0\”
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((224, 224))
img_raw = img.tobytes()
example = tf.train.Example(features=tf.train.Features(feature={"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))}))
writer.write(example.SerializeToString()) #序列化为字符串
writer.close()
于是数据相关的信息都被存到了一个文件中,包括example和label。
生成了TFRecords文件后,再使用队列读取数据,代码如下:
#生成文件名队列
filename = "train.tfrecords"
filename_queue = tf.train.string_input_producer([filename])
#定义阅读器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
#返回文件名和文件
features = tf.parse_single_example(serialized_example, features={'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string)})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [224, 224, 3])
image = tf.cast(image, tf.float32)*(1./255)
label = tf.cast(features['label'], tf.int32)
images, labels = tf.train.shuffle_batch([image, label], batch_size=30, capacity=2000, min_after_dequeue=1000)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess)
for i in range(10):
image_batch, label_batch= sess.run([images, labels])
print(image_batch.shape, label_batch)
* 确实比前两种方法麻烦,但是既然是官方标准格式,它总有自己的好处。
* 因为TF的graph能够记住状态(state),就是说TFRecordReader能够记住tfrecord的位置,这样才能不断返回下一个文件。因此在使用之前,必须初始化整个graph,tf.initialize_all_variables()的作用就是初始化。
* sess.run()时队列才执行,TFRecordReader会不断弹出队里中文件名,直到队列为空。