tf.data.Dataset使用

一、特点:

  • 数据集可以看作是计算图上的一个节点,因此Dataset的操作是基于Tensor的,例如map()函数

  • 该类在eager模式和非eager模式下的使用是不一致的。

  • 由于训练数据通常无法全部写入内存中,这样无法进行shuffle()和batch(),所以从数据集中读取数据时需要使用一个迭代器(iterator)按顺序进行读取。

  • drop_remainder = True
      这里设置成True, 会丢弃最后多余的,不是一个batch的数据。

二、构建Dataset:

2.1 from_tensor_slices()

2.1.1 函数参数

from_tensor_slices(
    tensors
)

注意:如果tensors包含numpy array,并且没有启用eager 模式,value将会作为 tf.constant 不断嵌入计算图中,导致计算图越来越大。

参数可以是Dict, 也可以是Tuple:

dict : {"token_ids": [[1, 2, 3, 0, 0],  [5, 6, 7, 8, 9],  [1, 1, 1, 1, 1]], "labels":[1, 2, 3]} 
tuple: (input_array, label_array)

2.1.2 例子

  • 数据集确定.
    数据集确定情况下, 可以使用one_shot_iterator简单遍历数据集,而不需要初始化数据集。代码如下所示.
import tensorflow as tf 

# 从数组中创建数据集. 
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)

# 定义迭代器, 遍历数据集. 
iterator = dataset.make_one_shot_iterator()

# 获取一份数据. 
x = iterator.get_next()
y = x * x 

with tf.Session() as sess:
	# 如果不做batch()操作, 默认一次get_next()得一份数据. 
	for i in range(len(input_data)):  
		print(sess.run(y))
  • 数据集不确定
    数据集不确定的意思是: 使用tf.placeholder()来占位数据集. 此时迭代器必须使用initializable_iterator动态初始化数据集.
import tensorflow as tf 

# 使用tf.placeholder. 
input_data = tf.placeholder(dtype=tf.float32) 
dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.shuffle(4).batch(4)   # 注意想要batch, 必须保证每条数据的dimension相同. 

# 定义迭代器, 遍历数据集. 
iterator = dataset.make_initializable_iterator()

# 获取一份数据. 
x = iterator.get_next()
y = x * x 

with tf.Session() as sess:
	feed_data = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
	# 初始化训练数据的迭代器
	sess.run(iterator.initializer, feed_dict={input_data: feed_data})

	# 如果不做batch()操作, 默认一次get_next()得一份数据. 
	while True:
		try:
			print_y = sess.run(y)
			print(print_y)
		except tf.errors.OutOfRangeError:
			break

2.2 TextLineDataset()

2.2.1 函数参数

tf.data.TextLineDataset(
    filenames, compression_type=None, buffer_size=None, num_parallel_reads=None
)
  • filenames:
    tf.string or tf.data.Dataset包含一个或多个filenames,然而实际上[“file_1”, “file_2”]好像也行。

2.2.2 例子

import tensorflow as tf 

# 从文本文件创建数据集
input_files = ["a.txt", "b.txt"]
dataset = tf.data.TextLineDataset(input_files)

# 定义迭代器用于遍历数据集
iterator = dataset.make_one_shot_iterator()

# 得到数据, 每次get的是文件中的一行
x = iterator. get_next()

with tf.Session() as sess : 
	for in range(3) : 
		print(sess. run(x))

2.3 TFRecordDataset()

2.3.1 函数参数

tf.data.TFRecordDataset(
    filenames, compression_type=None, buffer_size=None, num_parallel_reads=None
)
  • filenames :
    也是tensor类型, 可以placeholder占用,传参。

2.3.2 例子

以下是加载TFRecord文件,作为Dataset的例子:
注意: 使用tf.FixedLenFeature()需要指定List的长度.
而使用tf.VarLenFeature()无需指定List的长度,也不需要[]占位,只需要指定type即可。

注意: 使用tf.VarLenFeature()得到的是SparseTensor, 所以之后获取值的时候,需要访问.values属性。

def parser_single_record(record):
    """
    解析读入的一个样例.
    :param record:   		tfrecord中的一条数据. 
    :return:
    """
    features = tf.parse_single_example(
        record,
        features={
            "pixels": tf.FixedLenFeature([pixels_len], tf.int64),  # tf.FixedLenFeature 这种方法解析的结果为一个Tensor.
            "label": tf.FixedLenFeature([label_len], tf.int64),  # tf.VarLenFeature 这种方法解析的结果为SparseTensor.
            "image_raw": tf.FixedLenFeature([image_raw_len], tf.string)
        }
    )

    return features["pixels"], features["label"], features["image_raw"]
 
 
if __name__ == "__main__":
	input_files = tf.placeholder(tf.string)
	dataset = tf.data.TFRecordDataset(input_files)

	# 对dataset中元素进行解析.
    dataset = dataset.map(parser_single_record) 

    # 定义iterator
    iterator = dataset.make_initializable_iterator()
    pixel_features, label_features, image_features = iterator.get_next()

    with tf.Session() as sess:
        sess.run(iterator.initializer, feed_dict={input_files: ["./input_file1", "./input_file2"]})

        while True:
            try:
                sess.run([pixel_features, label_features, image_features])
            except tf.errors.OutOfRangeError:
                break

2.3.3 创建TFRecord文件

意义在于: 使用TFRecord格式统一存储数据,因为当数据来源更加复杂、每一个样例中的信息更加丰富之后,以往的方式很难有效管理。

TFRecord 文件中的数据都是通过tf.train.Example格式存储的,下面是tf.train.Example的定义。

message Example {
	Features features = 1;
} ;

message Features {
	map<string, Feature> feature = 1;
}; 

message Feature {
	字符串列表BytesList, 实数列表FloatList, 或整数列表Int64List中的一种。
}

以下是实例:
注意: tf.train.Int64List(value) 以及其他List的 value参数必须是一维List, 如果是嵌套List必须拉平。
使用tf.train.BytesList(value = [str_txt.encoder()])对于str类型,传入BytesList必须调用encoder

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))   
    # 这个传入的value是单值, 也就是说参数value对应的必须是1维list, 如果嵌套需要拉平.


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def write_tfrecord():
    # TFRecord文件的地址
    filename = "/path/to/output.tfrecords"
    # 创建一个writer写TFRecord文件
    writer = tf.python_io.TFRecordWriter(filename)

    for index in range(num_examples):
        # 将图像矩阵转化为一个字符串.
        image_raw = images[index].tostring()

        # 将一条数据转换为Example.
        example = tf.train.Example(features=tf.train.Features(feature={
            "pixels": _int64_feature(pixels),
            "labels": _int64_feature(np.argmax(labels[index])),
            "image_raw": _bytes_feature(image_raw)
        }))

        # 写入TFRecord文件.
        writer.write(example.SerializeToString())

    writer.close()

三、使用Dataset

3.1 map()

map将处理后的数据包装成一个新的数据集返回,可以直接调用其他高层操作。

注意: map的操作对象是Tensor, 所以对应函数也是基于Tensor进行处理的, 如果想使用numpy方法, 需要使用tf.py_function().

但我这边,使用tf.py_function实现word2idx,一直没成功。

3.1.2 例子

dataset = dataset.map(lambda x, y: 
					  (tf.reshape(x, [-1, hidden_size]), y, x.shape[1], y.shape[1]))

这样使用lambda对dataset中的元素逐条处理, 注意后面是tuple()
通过上述map, 参数由2个变成了4个。

3.2 shuffle()

shuffle在内部使用一个缓存区保存buffer_size条数据,每读入一条新数据,从这个缓存区中随机选择一条数据进行输出,缓存区的大小越大,随机的性能越好,占用内存越多

3.2.1 函数参数

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None
)
  • reshuffle_each_iteration:
    Boolen类型.

返回的也是Dataset.

3.3 batch()

3.3.1 函数参数

batch(
    batch_size, drop_remainder=False
)

3.4 repeat()

通过repeat方法训练多个epoch

3.4.1 函数参数

batch(
    batch_size, drop_remainder=False
)

3.4.2 例子

注意: 如果数据集在repeat前已经进行了shuffle操作,输出的每个epoch中随机shuffle的结果并不会相同。

用了repeat,就没有for epoch in range(epochs)了。

4. 附

data = tf.data.Dataset.from_tensor_slices(data_features).shuffle(shuffle_size).batch(
            batch_size, drop_remainder=True)
    
data_iterator = data.make_initializable_iterator()
batch_data = iterator.get_next()

with tf.Session(config=config) as sess:
	sess.run(tf.global_variables_initializer())
	
	for epoch in range(FLAGS.num_train_epoches):
		
		sess.run(data_iterator.initializer)
		
        while True:
        	try:
            	data = sess.run(train_batch_data)
            except tf.errors.OutOfRangeError:
                break
            feed = {
                input_ids: data["token_ids"],
                target: data["labels"],
                embedding_word2vec_inputs: embedding_word2vec_matrix}
                
                _, loss = sess.run([optimizer, summary_losses], feed_dict=feed)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值