读写tfrecord文件


在训练模型的时候,一般会将数据预处理转换成tfrecord格式,负责I/O操作的CPU和进行数值运行计算的GPU相互之间可以并行工作,保证GPU高的利用率。以下是对特征是定长和变长读写tfrecord方式。

1 写tfrecor方式

一般会将数据按照模型训练所需要的方式对输入x和label标签进行tfrecord格式转换。主要有定长和变长两种方式,根据实际应用和需求决定。若输入的每个example的input 是变长的,比如每个example的输入特征索引个数不是相同的,则可以按照变长的方式转换,否则按照定长的方式转换。

1.1 变长特征转tfrecord

import collections
writer = tf.python_io.TFRecordWriter('data.tfrecord')
def toTF(data):
	''' 
	data是一个dict,假设其中key有input_x和input_y,
	对应的value是索引list
	'''
	features = collections.OrderedDict()
	input_x = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x"])))
	features["input_x"] = tf.train.FeatureList(feature=input_x)
	input_y = tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y"])))
	features["input_y"] = tf.train.FeatureList(feature=input_y)
	sequence_example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=features))
	writer.write(sequence_example.SerializeToString())

以下方式实现与上面方式等价:

def toTF_v2(data)
	sequence_example = tf.train.SequenceExample()
	input_x = sequence_example.feature_lists.feature_list["input_x"]
	input_y = sequence_example.feature_lists.feature_list["input_y"]
	for x in data["input_x"]:
		input_x.feature.add().int64_list.value.append(x)
	for y in data["input_y"]:
		input_y.feature.add().int64_list.value.append(y)
	writer.write(sequence_example.SerializeToString())

1.2 定长特征转tfrecord

def toTF_fixed(data):
	features = collections.OrderedDict()
	features["input_x"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_x")))
	features["input_y"]= tf.train.Feature(int64_list=tf.train.Int64List(value=list(data["input_y")))
	example = tf.train.Example(features=tf.train.Features(feature=features))
	write.write(example.SerializeToString())

2 读tfrecord

和写trrecord一样,也分定长和变长方式,如果写tfrecord是定长方式,则读tfrecord也需要定长方式。读写方式需要保持一致。

2.1 变长方式读tfrecord

需要定义特征的格式,如果是变长则定义tf.FixedLenSequenceFeature类型特征

import tensorflow as tf
features = {
			'input_x': tf.FixedLenSequenceFeature([], tf.int64)
			'input_y': tf.FixedLenSequenceFeature([], tf.int64)
			}

2.2 定长方式读tfrecord

定长方式用tf.FixedLenFeature类型

seq_length = 10
features = {
		'input_x': tf.FixedLenFeature([seq_length], tf.int64).
		'input_y': tf.FixedLenFeature([seq_length], tf.int64
		}

3 从hdfs中读取批量tfrecord文件

当训练数据量级很大时,一般转tfrecord试用分布式方式处理数据,提高效率。训练模型的时候,可以从远程,例如hdfs上读取批量文件。以下是从hdfs上批量读取tfrecord文件。

def input_fn_builder(file_path, num_cpu_threads, seq_length, num_class, batch_size):
	'''
	其中file_path是hdfs上文件的路径,比如data目录下的所有tfrecord文件
	读的是定长的feature
	'''
	features = {
			'input_x': tf.FixedLenFeature([seq_length], tf.int64),
			'input_y': tf.FixedLenFeature([seq_length], tf.int64),
	}
	def _decode_record(record):
		# 一个样本解析
		example = tf.io.parse_single_example(record, features)
		multi_label_enc = tf.one_hot(indices=example["input_y"], depth=num_class)
		example["input_y"] = tf.reduce_sum(multi_label_enc, axis=0)
		return example

	def _decode_batch_record(batch_record):
		# 一个batch样本解析
		batch_example = tf.io.parse_example(serialized=batch_record, features=features)
		multi_label_enc = tf.one_hot(indices=batch_example["input_y"], depth=num_class)
		batch_example["input_y"] = tf.reduce_sum(multi_label_enc, axis=1)
		return batch_example

	def input_fn(params):
		# d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
		d = tf.data.Dataset.list_files(file_path)
		d = d.repeat()
		d = d.shuffle(buffer_size=100)
		d = d.appley(
			tf.contrib.data.parallel_interleave(
				tf.data.TFRecordDataset,
				sloppy=True,
				cycle_length=num_cpu_threads))
		d = d.apply(
			tf.contrib.data.map_and_batch(
					lambda record: _decode_record(record),
					batch_size = batch_size,
					num_parallel_batches=num_cpu_threads,
					drop_remainder=True))
		return d

	def input_fn_v2(params):
		d = tf.data.Dataset.list_files(file_path)
		d = d.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=num_cpu_threads, block_length=128).\
		batch(batch_size).map(_decode_batch_record, num_parallel_calls=tf.data.experimental.AUTOTRUE).prefetch(
			tf.data.experimental.AUTOTUNE).repeat()
		return d
	return input_fn
	#return input_fn_v2

上面提供了两个解析函数,input_fn和input_fn_v2两种方式都可行,配合estimator方式训练,可以使得CPU读取数据与GPU训练数据之间可以并行处理,减少等待时间,提高GPU的利用率,加快训练速度。解析tfrecord文件时,有下面四种方式,根据自己具体的数据格式进行选择:

  • 解析单个样本,定长特征:tf.io.parse_single_example()
  • 解析单个样本,变长特征:tf.io.parse_single_sequence_example()
  • 解析批量样本,定长特征:tf.io.parse_example()
  • 解析批量样本,定长特征:tf.io.parse_sequence_example()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Scala是一种强大的编程语言,它支持读写各种数据格式,包括tfrecordtfrecord是一种用于存储大型数据集的二进制文件格式,被广泛用于训练深度学习模型。 要在Scala中读取tfrecord文件,我们可以使用TensorFlow的Scala API。首先,我们需要导入必要的依赖项,包括TensorFlow和相关的Scala库。 ```scala import org.tensorflow._ import org.tensorflow.example._ ``` 然后,我们可以使用TensorFlow的`TFRecordReader`类来读取tfrecord文件。首先,创建一个`TFRecordReader`对象,并指定要读取文件路径。 ```scala val reader = new TFRecordReader() reader.initialize(new TFRecordReader.Options().setInputFiles("path/to/tfrecord/file.tfrecord")) ``` 接下来,我们可以使用`reader.getNext`方法逐个读取tfrecord文件中的记录,并将其转换为TensorFlow的`Example`对象。 ```scala var record = reader.getNext() while (record != null) { val example = Example.parseFrom(record.toByteArray()) // 对Example对象进行处理,例如打印其中的特征值 println(example.getFeatures) // 读取下一个记录 record = reader.getNext() } ``` 类似地,在Scala中写入tfrecord文件也非常简单。我们可以使用TensorFlow的`TFRecordWriter`类来创建一个tfrecord文件,并将数据写入其中。 ```scala import java.nio.file.Files import java.nio.file.Paths // 创建一个tfrecord文件的输出流 val writer = new TFRecordWriter(Files.newOutputStream(Paths.get("path/to/output.tfrecord"))) // 创建一个Example对象,添加特征并写入文件 val exampleBuilder = Example.newBuilder() exampleBuilder.getFeaturesBuilder().putFeature("feature_key", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1)).build()) writer.write(exampleBuilder.build().toByteArray()) // 关闭输出流 writer.close() ``` 在这个示例中,我们创建了一个`TFRecordWriter`对象,并指定了要写入的tfrecord文件的路径。然后,我们创建一个`Example`对象,并使用`exampleBuilder`设置了一个特征。最后,我们通过调用`write`方法将`Example`对象写入tfrecord文件,并通过调用`close`方法关闭输出流。 通过使用TensorFlow的Scala API,我们可以很容易地读写tfrecord文件,并在Scala中处理大型数据集。 Scala 提供了丰富的库和功能,使得读写tfrecord文件变得非常方便和高效。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值