文章目录
用户处理数据集的典型流程是:
- 输入数据集从文件系统读取到内存中;
- 将其转换为模型所需要的格式;
- 以某种形式传入到数据流图中,开始模型训练。
一般采用的数据读取方式有两种:
- Large-scale Dataset:一般由由大量数据文件组成。因为数据规模太大,所以无法一次性全部加载到内存。但是,如果每进行一步模型训练就加载一次所需的batch data,这将阻塞模型的训练过程。为了减小数据读取对模型训练效率的影响,常用的方法是通过多线程并行读取数据。
- Tiny Dataset:由较少数据文件组成,能够在数据模型开始前一次性将所有data load到memory中。
数据并行读取
large scale dataset一般无法一次性加载到内存中进行处理,如ImageNet数据集。当处理如此规模的数据集时,TF提供了以输入流水线方式从多个文件中并行读取数据的方法,这使得模型训练时能够有充足的数据能够feed to graph。
它的主要步骤如下:
- 创建文件名列表;
- 创建文件名队列;
- 创建Reader & Decoder;
- 创建样例队列;
下图展示了完整的并行数据读取流水线:
上图样例队列和批样例队列中的元素名,指示哪个文件中的第几个数据记录(即第几行)。
理解并行流水线数据读取过程的关键是掌握文件名队列和样例队列。文件名队列为读取数据文件提供了一个缓冲区,样例队列为数据feed to graph提供了缓冲区。
创建文件名列表
文件名列表是指输入数据集中所有文件的名称所构成的列表。列表中的元素可能是在本地文件系统上的文件位置,也可能是共享文件系统或分布式文件系统上的统一资源标识符(URI)。
两种创建文件名列表的方法:
- python list:如果文件名的个数不多,或者文件命名遵循rules,那么用户可以直接使用list存储文件名。
tf.train.match_filenames_once()
:该方法在graph创建了一个获取文件名列表的操作,它的输入是文件名列表的匹配模式,输出是一个存储了符合该匹配模式的文件名列表variable。在初始化全局变量是,该文件名列变量也会被初始化。
创建文件名队列
一般使用tf.train.string_input_producer()
创建文件名队列,它的输入是之前创建的文件名列表,输出是一个先入先出的queue。
epoch:完整遍历一次输入数据集即为一个epoch。
训练模型需要反复遍历整个输入数据集,以不断更新模型参数。
用户可以通过tf.train.string_input_producer()
的输入参数num_epoches设置模型的最大训练周期数。但每次新的epoch,我们希望模型数据顺序是变化的,以防止模型过拟合。因此,可以设置tf.train.string_input_producer()
的输入参数 shuffle = True,此时程序就可以打乱每个epoch内的文件名顺序。
tf.train.string_input_producer()
所有输入参数:
- string_tensor:存储文件名列表的字符串张量
- num_epochs:最大训练周期
- shuffle:是否打乱文件顺序
- seed:随机化种子,用于文件打乱
- capacity:filename queue的容量(长度
- shareed_name:多个sessions见共享的文件名队列
- name:创建文件名队列操作的名称
- cancel_op:取消队列操作
创建Reader & Decoder
Reader的功能就是读取数据,Decoder的功能是将数据转换为张量格式。两者都与数据文件格式有关。下表给出了TF推荐的三种数据文件格式及其对应Reader&Decoder。
文件格式 | Reader类型 | Decoder类型 |
---|---|---|
CSV file | tf.TextLineReader() | tf.decode_csv |
TFRecords file | tf.TFRecordReader() | tf.parse_single_example |
自由格式文件 | tf.FixedLengthRecordReader() | tf.decode_raw |
一般流程:首先创建数据文件对应的Reader,然后从文件名队列中取出文件名,并传入Reader.read()方法,最后使用对应的Decoder将数据记录中的每一列数据都转换为张量格式。
CSV file
字符分隔值(Comma-Seperated Values, CSV)文件是以纯文本形式存储表格数据。CSV的一般标准是:
CSV由多条数据记录组成,数据记录之间以某种换行符进行分隔。每条记录由多个字段组成,字段间通常以制表符或逗号分隔。所有记录拥有相同的字段序列格式。
以读取多个记录收入支出表(file1.csv & file2.csv)为例,展示TF如何读取CSV file。其中部分表格内容如下所示:
year | month | income | outgo |
---|---|---|---|
2021 | 1 | 40000 | 20000 |
2021 | 2 | 42000 | 19000 |
每条数据记录包含四个字段:year, month, income, outgo。示例代码如下:
# create filename queue
filename_list = ['file1.csv', 'file2.csv']
filename_queue = tf.train.string_input_producer(filename_list)
# create Reader for csv file
reader = tf.TextLineReader()
# read one row from csv file
_, value = reader.read(filename_queue)
# setting default value
record_defaults = [[2021], [0], [0.0], [0.0]]
# transfer data to Tensor with tf.decode_csv
year, month, income, outgo = tf.decode_csv(value, record_defaults)
features = tf.stack([year, month, income, outgo])
注,Reader.read()
方法只能读取一行数据记录,Reader.read_up_to()
方法能够一次读取多条数据,通过设置其num_records
参数,可以显式地制定一次读取的数据记录数量。
tf.decode_csv()
方法中的record_defaults
参数,是为了给数据记录中的某些不合法或不存在的字段填充默认值,以确保程序正常执行。注意,以上的Reader.read()
方法还有tf.decode_csv()
方法返回的都是graph ops,而不是real data,用户需要通过session.run()
才能获得data。
TFRecords file
TFRecords文件存储的是有结构的序列化字符块,他是TF推荐的standard file format。但是一般数据集的annotations都没有采用该类数据格式,因此我们在这里不做更多介绍。
Any format file
自由格式文件是用户自定义的二进制文件。它的存储对象是字符串,每条数据记录都是一个固定长度的字节块。因此如果要想正确识别和转换二进制文件中的数据记录,必须使用tf.FixedLengthRecordReader()
读取二进制文件中固定长度的字节块,然后使用tf.decode_raw()
方法将读取的字符串转换为张量。tf.FixedLengthRecordReader()
和tf.TextLineReader()
均继承自ReaderBase类,都支持一次读取多条记录的方法。
tf.decode_raw()方法的功能是将字符串转换为张量,其prototype如下所示:
tf.decode_raw(bytes, out_type, little_endian=None, name=None)
创建样例队列
在“CSV file”小结,我们得到了year, month, income, outgo四个特征张量。在会话执行时,为了使计算机任务顺利获得输入数据,我们需要使用tf.train.start_queue_runners()方法启动执行入队操作的所有线程,具体包括文件名入队到filename_queue的操作,样例入队到样例队列的操作。
示例代码如下:
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# 启动所有执行入队操作的后台线程
tf.train.start_queue_runners(sess=sess)
for i in range(2):
example = sess.run(features)
print(example)
上述代码中的features是之前创建的数据读取、解析的操作。同时上述代码并不适合生产环境,因为其较差的容错性:无人管理队列操作后台线程的生命周期,任何线程出现异常都会导致程序崩溃。为了解决该问题,可以使用tf.train.Coordinator()
方法构建管理多线程生命周期的协调器。它会监控TF所有后台线程,但其中某个线程出现异常时,Coordinator.should_stop()
将返回True,使for循环结束。然后执行finally中Coordinator.request_stop()
方法,请求所有线程安全退出。
需要注意的是,使用Coordinator管理multi-threads之前,需要先执行tf.local_variables_initializer()
方法对其进行初始化。所以使用tf.group()
方法将tf.local_variables_initializer()
和tf.global_variables_initializer()
聚合生成整个程序的初始化操作init_op。
示例代码如下:
import tensorflow as tf
# create filename_queue and setting epochs = 5
filename_list = ['file1.csv', 'file2.csv']
filename_queue = tf.train.string_input_producer(filename_list , num_epochs=5)
...
# aggregate local and global initialization ops
init_op = tf.group(tf.local_variables_initializer(),
tf.global_variables_initializer())
with tf.Session() as sess:
init_op.run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord=coord)
print("=> Threads: ", threads)
try:
for i in range(10):
if not coord.should_stop():
example = sess.run(features)
print(example)
except tf.errors.OutofRangeError:
print('=> Catch OutofRangeError')
finally:
# request to stop all the threads in background
coord.request_stop()
print('=> Finish reading ...')
coord.join(threads)
创建批样例数据
通过上节内容,我们成功获得了数据样例,但是需要将这些样例聚合成批数据才能用于模型训练、评估和推理使用。TF提供的tf.train.shuffle_batch()
方法不仅能够使用样例创建批数据,而且能够在打包过程中打乱样例顺序。增加随机性。
示例代码:
filename_queue = ...
examples = ...
# batch queue settings
batch_size = 16
min_after_dequeque = 10000 # 样例队列中出队的样例个数
capacity = min_after_dequeque + 3 * batch_size # 批数据队列容量
# create batch queue
batch_queue = tf.train.shuffle_batch([examples],
batch_size=batch_size,
capacity=capacity,
min_after_dequeque=min_after_deque)
tf.train.shuffle_batch()
除了上面使用的参数外,常用的还有设置线程个数的num_threads参数,设置随机化种子的seed参数,以及设置多条样例入队的enqueue_many参数。
填充数据节点
使用批数据训练的模型基本上都是用feed数据节点的方法,他不需要读取完整的数据集,有效减少了内存开销。同时,基于并行输入流水线的数据读取方法保证了实时性,与将全部数据预加载到内存中,对训练结果没有明显差距。
CIFAR-10数据集示例
CIFAR-10数据集总共包含60000张32x32x3的图像,图片总共有10类,每一类6000张图片, 下载地址。整个数据集被分为6个批数据,每一批数据包含1W张图片,其中5W张用于模型训练,1W张用于模型测试。
在CIFAR-10数据集中,一条数据记录由类别标签和图像数据两部分组成。单张图片需要3072个字节,类别标签1个字节。因此,CIFAR-10数据集中单条记录占用3073个字节,他们以二进制数据文件格式存储。
示例代码:
import tensorflow as tf
# label length
LABEL_BYTES = 1
# image size
IMAGE_SIZE = 32
# image channel
IMAGE_CHANNEL = 3
# image data length
IMAGE_BYTES = IMAGE_SIZE * IMAGE_SIZE * IMAGE_CHANNEL
# classes num
NUM_CLASSES = 10
def read_cifar10(data_file, batch_size):
"""
input params:
data_file: CIFAR-10 data file
batch_size: the size of batch
returns:
images: images batch following format [batch_size, IMAGE_SIZE, IMAGE_SIZE]
labels: labels batch following format [batch_size, NUM_CLASSES]
"""
record_bytes = LABEL_BYTES + IMAGE_BYTES
# create filename list
data_files = tf.gfile.Glob(data_file)