一、特点:
-
数据集可以看作是计算图上的一个节点,因此Dataset的操作是基于Tensor的,例如map()函数
-
该类在eager模式和非eager模式下的使用是不一致的。
-
由于训练数据通常无法全部写入内存中,这样无法进行shuffle()和batch(),所以从数据集中读取数据时需要使用一个迭代器(iterator)按顺序进行读取。
-
drop_remainder = True
这里设置成True, 会丢弃最后多余的,不是一个batch的数据。
二、构建Dataset:
2.1 from_tensor_slices()
2.1.1 函数参数
from_tensor_slices(
tensors
)
注意:如果tensors包含numpy array,并且没有启用eager 模式,value将会作为 tf.constant 不断嵌入计算图中,导致计算图越来越大。
参数可以是Dict, 也可以是Tuple:
dict : {"token_ids": [[1, 2, 3, 0, 0], [5, 6, 7, 8, 9], [1, 1, 1, 1, 1]], "labels":[1, 2, 3]}
tuple: (input_array, label_array)
2.1.2 例子
- 数据集确定.
数据集确定情况下, 可以使用one_shot_iterator简单遍历数据集,而不需要初始化数据集。代码如下所示.
import tensorflow as tf
# 从数组中创建数据集.
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# 定义迭代器, 遍历数据集.
iterator = dataset.make_one_shot_iterator()
# 获取一份数据.
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
# 如果不做batch()操作, 默认一次get_next()得一份数据.
for i in range(len(input_data)):
print(sess.run(y))
- 数据集不确定
数据集不确定的意思是: 使用tf.placeholder()来占位数据集. 此时迭代器必须使用initializable_iterator动态初始化数据集.
import tensorflow as tf
# 使用tf.placeholder.
input_data = tf.placeholder(dtype=tf.float32)
dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.shuffle(4).batch(4) # 注意想要batch, 必须保证每条数据的dimension相同.
# 定义迭代器, 遍历数据集.
iterator = dataset.make_initializable_iterator()
# 获取一份数据.
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
feed_data = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
# 初始化训练数据的迭代器
sess.run(iterator.initializer, feed_dict={input_data: feed_data})
# 如果不做batch()操作, 默认一次get_next()得一份数据.
while True:
try:
print_y = sess.run(y)
print(print_y)
except tf.errors.OutOfRangeError:
break
2.2 TextLineDataset()
2.2.1 函数参数
tf.data.TextLineDataset(
filenames, compression_type=None, buffer_size=None, num_parallel_reads=None
)
- filenames:
tf.string
ortf.data.Dataset
包含一个或多个filenames,然而实际上[“file_1”, “file_2”]好像也行。
2.2.2 例子
import tensorflow as tf
# 从文本文件创建数据集
input_files = ["a.txt", "b.txt"]
dataset = tf.data.TextLineDataset(input_files)
# 定义迭代器用于遍历数据集
iterator = dataset.make_one_shot_iterator()
# 得到数据, 每次get的是文件中的一行
x = iterator. get_next()
with tf.Session() as sess :
for in range(3) :
print(sess. run(x))
2.3 TFRecordDataset()
2.3.1 函数参数
tf.data.TFRecordDataset(
filenames, compression_type=None, buffer_size=None, num_parallel_reads=None
)
- filenames :
也是tensor类型, 可以placeholder占用,传参。
2.3.2 例子
以下是加载TFRecord文件,作为Dataset的例子:
注意: 使用tf.FixedLenFeature()
需要指定List的长度.
而使用tf.VarLenFeature()
无需指定List的长度,也不需要[]
占位,只需要指定type即可。
注意: 使用tf.VarLenFeature()
得到的是SparseTensor
, 所以之后获取值的时候,需要访问.values属性。
def parser_single_record(record):
"""
解析读入的一个样例.
:param record: tfrecord中的一条数据.
:return:
"""
features = tf.parse_single_example(
record,
features={
"pixels": tf.FixedLenFeature([pixels_len], tf.int64), # tf.FixedLenFeature 这种方法解析的结果为一个Tensor.
"label": tf.FixedLenFeature([label_len], tf.int64), # tf.VarLenFeature 这种方法解析的结果为SparseTensor.
"image_raw": tf.FixedLenFeature([image_raw_len], tf.string)
}
)
return features["pixels"], features["label"], features["image_raw"]
if __name__ == "__main__":
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
# 对dataset中元素进行解析.
dataset = dataset.map(parser_single_record)
# 定义iterator
iterator = dataset.make_initializable_iterator()
pixel_features, label_features, image_features = iterator.get_next()
with tf.Session() as sess:
sess.run(iterator.initializer, feed_dict={input_files: ["./input_file1", "./input_file2"]})
while True:
try:
sess.run([pixel_features, label_features, image_features])
except tf.errors.OutOfRangeError:
break
2.3.3 创建TFRecord文件
意义在于: 使用TFRecord格式统一存储数据,因为当数据来源更加复杂、每一个样例中的信息更加丰富之后,以往的方式很难有效管理。
TFRecord 文件中的数据都是通过tf.train.Example格式存储的,下面是tf.train.Example的定义。
message Example {
Features features = 1;
} ;
message Features {
map<string, Feature> feature = 1;
};
message Feature {
字符串列表BytesList, 实数列表FloatList, 或整数列表Int64List中的一种。
}
以下是实例:
注意: tf.train.Int64List(value) 以及其他List的 value参数
必须是一维List, 如果是嵌套List必须拉平。
使用tf.train.BytesList(value = [str_txt.encoder()])
对于str类型,传入BytesList必须调用encoder
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 这个传入的value是单值, 也就是说参数value对应的必须是1维list, 如果嵌套需要拉平.
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def write_tfrecord():
# TFRecord文件的地址
filename = "/path/to/output.tfrecords"
# 创建一个writer写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
# 将图像矩阵转化为一个字符串.
image_raw = images[index].tostring()
# 将一条数据转换为Example.
example = tf.train.Example(features=tf.train.Features(feature={
"pixels": _int64_feature(pixels),
"labels": _int64_feature(np.argmax(labels[index])),
"image_raw": _bytes_feature(image_raw)
}))
# 写入TFRecord文件.
writer.write(example.SerializeToString())
writer.close()
三、使用Dataset
3.1 map()
map将处理后的数据包装成一个新的数据集返回,可以直接调用其他高层操作。
注意: map的操作对象是Tensor, 所以对应函数也是基于Tensor进行处理的, 如果想使用numpy方法, 需要使用tf.py_function()
.
但我这边,使用tf.py_function实现word2idx,一直没成功。
3.1.2 例子
dataset = dataset.map(lambda x, y:
(tf.reshape(x, [-1, hidden_size]), y, x.shape[1], y.shape[1]))
这样使用lambda对dataset中的元素逐条处理, 注意后面是tuple()
通过上述map, 参数由2个变成了4个。
3.2 shuffle()
shuffle在内部使用一个缓存区保存buffer_size条数据,每读入一条新数据,从这个缓存区中随机选择一条数据进行输出,缓存区的大小越大,随机的性能越好,占用内存越多
3.2.1 函数参数
shuffle(
buffer_size, seed=None, reshuffle_each_iteration=None
)
- reshuffle_each_iteration:
Boolen类型.
返回的也是Dataset.
3.3 batch()
3.3.1 函数参数
batch(
batch_size, drop_remainder=False
)
3.4 repeat()
通过repeat方法训练多个epoch
3.4.1 函数参数
batch(
batch_size, drop_remainder=False
)
3.4.2 例子
注意: 如果数据集在repeat前已经进行了shuffle操作,输出的每个epoch中随机shuffle的结果并不会相同。
用了repeat,就没有for epoch in range(epochs)
了。
4. 附
data = tf.data.Dataset.from_tensor_slices(data_features).shuffle(shuffle_size).batch(
batch_size, drop_remainder=True)
data_iterator = data.make_initializable_iterator()
batch_data = iterator.get_next()
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(FLAGS.num_train_epoches):
sess.run(data_iterator.initializer)
while True:
try:
data = sess.run(train_batch_data)
except tf.errors.OutOfRangeError:
break
feed = {
input_ids: data["token_ids"],
target: data["labels"],
embedding_word2vec_inputs: embedding_word2vec_matrix}
_, loss = sess.run([optimizer, summary_losses], feed_dict=feed)