《深入了解TensorFlow》笔记——Chapter 4.1 输入数据集


用户处理数据集的典型流程是:

  1. 输入数据集从文件系统读取到内存中;
  2. 将其转换为模型所需要的格式;
  3. 以某种形式传入到数据流图中,开始模型训练。

一般采用的数据读取方式有两种:

  • Large-scale Dataset:一般由由大量数据文件组成。因为数据规模太大,所以无法一次性全部加载到内存。但是,如果每进行一步模型训练就加载一次所需的batch data,这将阻塞模型的训练过程。为了减小数据读取对模型训练效率的影响,常用的方法是通过多线程并行读取数据。
  • Tiny Dataset:由较少数据文件组成,能够在数据模型开始前一次性将所有data load到memory中。

数据并行读取

large scale dataset一般无法一次性加载到内存中进行处理,如ImageNet数据集。当处理如此规模的数据集时,TF提供了以输入流水线方式从多个文件中并行读取数据的方法,这使得模型训练时能够有充足的数据能够feed to graph。
它的主要步骤如下:

  1. 创建文件名列表;
  2. 创建文件名队列;
  3. 创建Reader & Decoder;
  4. 创建样例队列;

下图展示了完整的并行数据读取流水线:

TF数据并行读取流水线上图样例队列和批样例队列中的元素名,指示哪个文件中的第几个数据记录(即第几行)。

理解并行流水线数据读取过程的关键是掌握文件名队列和样例队列。文件名队列为读取数据文件提供了一个缓冲区,样例队列为数据feed to graph提供了缓冲区。

创建文件名列表

文件名列表是指输入数据集中所有文件的名称所构成的列表。列表中的元素可能是在本地文件系统上的文件位置,也可能是共享文件系统或分布式文件系统上的统一资源标识符(URI)。

两种创建文件名列表的方法:

  • python list:如果文件名的个数不多,或者文件命名遵循rules,那么用户可以直接使用list存储文件名。
  • tf.train.match_filenames_once():该方法在graph创建了一个获取文件名列表的操作,它的输入是文件名列表的匹配模式,输出是一个存储了符合该匹配模式的文件名列表variable。在初始化全局变量是,该文件名列变量也会被初始化。

创建文件名队列

一般使用tf.train.string_input_producer()创建文件名队列,它的输入是之前创建的文件名列表,输出是一个先入先出的queue。

epoch:完整遍历一次输入数据集即为一个epoch。
训练模型需要反复遍历整个输入数据集,以不断更新模型参数。

用户可以通过tf.train.string_input_producer()的输入参数num_epoches设置模型的最大训练周期数。但每次新的epoch,我们希望模型数据顺序是变化的,以防止模型过拟合。因此,可以设置tf.train.string_input_producer()的输入参数 shuffle = True,此时程序就可以打乱每个epoch内的文件名顺序。

tf.train.string_input_producer()所有输入参数:

  • string_tensor:存储文件名列表的字符串张量
  • num_epochs:最大训练周期
  • shuffle:是否打乱文件顺序
  • seed:随机化种子,用于文件打乱
  • capacity:filename queue的容量(长度
  • shareed_name:多个sessions见共享的文件名队列
  • name:创建文件名队列操作的名称
  • cancel_op:取消队列操作

创建Reader & Decoder

Reader的功能就是读取数据,Decoder的功能是将数据转换为张量格式。两者都与数据文件格式有关。下表给出了TF推荐的三种数据文件格式及其对应Reader&Decoder。

文件格式Reader类型Decoder类型
CSV filetf.TextLineReader()tf.decode_csv
TFRecords filetf.TFRecordReader()tf.parse_single_example
自由格式文件tf.FixedLengthRecordReader()tf.decode_raw

一般流程:首先创建数据文件对应的Reader,然后从文件名队列中取出文件名,并传入Reader.read()方法,最后使用对应的Decoder将数据记录中的每一列数据都转换为张量格式。

CSV file

字符分隔值(Comma-Seperated Values, CSV)文件是以纯文本形式存储表格数据。CSV的一般标准是:

CSV由多条数据记录组成,数据记录之间以某种换行符进行分隔。每条记录由多个字段组成,字段间通常以制表符或逗号分隔。所有记录拥有相同的字段序列格式。

以读取多个记录收入支出表(file1.csv & file2.csv)为例,展示TF如何读取CSV file。其中部分表格内容如下所示:

yearmonthincomeoutgo
202114000020000
202124200019000

每条数据记录包含四个字段:year, month, income, outgo。示例代码如下:

# create filename queue
filename_list = ['file1.csv', 'file2.csv']
filename_queue = tf.train.string_input_producer(filename_list)
# create Reader for csv file
reader = tf.TextLineReader()
# read one row from csv file
_, value = reader.read(filename_queue)
# setting default value 
record_defaults = [[2021], [0], [0.0], [0.0]]
# transfer data to Tensor with tf.decode_csv
year, month, income, outgo = tf.decode_csv(value, record_defaults)

features = tf.stack([year, month, income, outgo])

注,Reader.read()方法只能读取一行数据记录,Reader.read_up_to()方法能够一次读取多条数据,通过设置其num_records参数,可以显式地制定一次读取的数据记录数量。

tf.decode_csv()方法中的record_defaults参数,是为了给数据记录中的某些不合法或不存在的字段填充默认值,以确保程序正常执行。注意,以上的Reader.read()方法还有tf.decode_csv()方法返回的都是graph ops,而不是real data,用户需要通过session.run()才能获得data。

TFRecords file

TFRecords文件存储的是有结构的序列化字符块,他是TF推荐的standard file format。但是一般数据集的annotations都没有采用该类数据格式,因此我们在这里不做更多介绍。

Any format file

自由格式文件是用户自定义的二进制文件。它的存储对象是字符串,每条数据记录都是一个固定长度的字节块。因此如果要想正确识别和转换二进制文件中的数据记录,必须使用tf.FixedLengthRecordReader()读取二进制文件中固定长度的字节块,然后使用tf.decode_raw()方法将读取的字符串转换为张量。tf.FixedLengthRecordReader()tf.TextLineReader()均继承自ReaderBase类,都支持一次读取多条记录的方法。

tf.decode_raw()方法的功能是将字符串转换为张量,其prototype如下所示:

tf.decode_raw(bytes, out_type, little_endian=None, name=None)

创建样例队列

在“CSV file”小结,我们得到了year, month, income, outgo四个特征张量。在会话执行时,为了使计算机任务顺利获得输入数据,我们需要使用tf.train.start_queue_runners()方法启动执行入队操作的所有线程,具体包括文件名入队到filename_queue的操作,样例入队到样例队列的操作。

示例代码如下:

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
	sess.run(init_op)
	# 启动所有执行入队操作的后台线程
	tf.train.start_queue_runners(sess=sess)
	for i in range(2):
		example = sess.run(features)
		print(example)

上述代码中的features是之前创建的数据读取、解析的操作。同时上述代码并不适合生产环境,因为其较差的容错性:无人管理队列操作后台线程的生命周期,任何线程出现异常都会导致程序崩溃。为了解决该问题,可以使用tf.train.Coordinator()方法构建管理多线程生命周期的协调器。它会监控TF所有后台线程,但其中某个线程出现异常时,Coordinator.should_stop()将返回True,使for循环结束。然后执行finally中Coordinator.request_stop()方法,请求所有线程安全退出。

需要注意的是,使用Coordinator管理multi-threads之前,需要先执行tf.local_variables_initializer()方法对其进行初始化。所以使用tf.group()方法将tf.local_variables_initializer()tf.global_variables_initializer()聚合生成整个程序的初始化操作init_op。

示例代码如下:

import tensorflow as tf
# create filename_queue and setting epochs = 5
filename_list = ['file1.csv', 'file2.csv']
filename_queue = tf.train.string_input_producer(filename_list , num_epochs=5)

...

# aggregate local and global initialization ops
init_op = tf.group(tf.local_variables_initializer(),
					tf.global_variables_initializer())

with tf.Session() as sess:
	init_op.run()
	coord = tf.train.Coordinator()
	threads = tf.train.start_queue_runners(sess, coord=coord)
	print("=> Threads: ", threads)
	try:
		for i in range(10):
			if not coord.should_stop():
				example = sess.run(features)
				print(example)
	except tf.errors.OutofRangeError:
		print('=> Catch OutofRangeError')
	finally:
		# request to stop all the threads in background
		coord.request_stop()
		print('=> Finish reading ...')
	coord.join(threads)

创建批样例数据

通过上节内容,我们成功获得了数据样例,但是需要将这些样例聚合成批数据才能用于模型训练、评估和推理使用。TF提供的tf.train.shuffle_batch()方法不仅能够使用样例创建批数据,而且能够在打包过程中打乱样例顺序。增加随机性。

示例代码:

filename_queue = ...
examples = ...

# batch queue settings
batch_size = 16
min_after_dequeque = 10000 # 样例队列中出队的样例个数
capacity = min_after_dequeque + 3 * batch_size # 批数据队列容量
# create batch queue
batch_queue = tf.train.shuffle_batch([examples], 
									batch_size=batch_size, 
									capacity=capacity, 
									min_after_dequeque=min_after_deque)

tf.train.shuffle_batch()除了上面使用的参数外,常用的还有设置线程个数的num_threads参数,设置随机化种子的seed参数,以及设置多条样例入队的enqueue_many参数。

填充数据节点

使用批数据训练的模型基本上都是用feed数据节点的方法,他不需要读取完整的数据集,有效减少了内存开销。同时,基于并行输入流水线的数据读取方法保证了实时性,与将全部数据预加载到内存中,对训练结果没有明显差距。

CIFAR-10数据集示例

CIFAR-10数据集总共包含60000张32x32x3的图像,图片总共有10类,每一类6000张图片, 下载地址。整个数据集被分为6个批数据,每一批数据包含1W张图片,其中5W张用于模型训练,1W张用于模型测试。

在CIFAR-10数据集中,一条数据记录由类别标签和图像数据两部分组成。单张图片需要3072个字节,类别标签1个字节。因此,CIFAR-10数据集中单条记录占用3073个字节,他们以二进制数据文件格式存储。

示例代码:

import tensorflow as tf

# label length
LABEL_BYTES = 1
# image size
IMAGE_SIZE = 32
# image channel
IMAGE_CHANNEL = 3
# image data length
IMAGE_BYTES = IMAGE_SIZE * IMAGE_SIZE * IMAGE_CHANNEL
# classes num
NUM_CLASSES = 10

def read_cifar10(data_file, batch_size):
	"""
	input params:
		data_file: CIFAR-10 data file
		batch_size: the size of batch
	returns:
		images: images batch following format [batch_size, IMAGE_SIZE, IMAGE_SIZE]
		labels: labels batch following format [batch_size, NUM_CLASSES]
	"""
	record_bytes = LABEL_BYTES + IMAGE_BYTES
	# create filename list
	data_files = tf.gfile.Glob(data_file)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dongz__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值