鄙人是一位 TensorFlow 萌新，在尝试生成 TFRecord 文件时遇到了错误，恳请各位大佬解答。代码如下：
def serialize_example(x, y):
    """Convert one (features, label) pair into a serialized ``tf.train.Example``.

    Requires eager execution (the TF 2.x default), because the values are
    materialized with ``.numpy()``.

    Args:
        x: Feature values for a single example -- a ``tf.Tensor``, NumPy
            array or Python sequence of numbers (any shape; it is flattened).
        y: Label value(s) for the same example -- a scalar or a sequence
            (any shape; it is flattened).

    Returns:
        bytes: The example serialized with ``SerializeToString()``. It holds
        two float features, ``"input_features"`` and ``"label"``.
    """

    def _float_list(values):
        # Iterating an EagerTensor yields 0-d tf.Tensor objects, which some
        # protobuf backends reject inside FloatList ("has type EagerTensor,
        # but expected one of: int, float"), and a 0-d label is not iterable
        # at all. Cast, flatten and convert to plain Python floats to avoid
        # both problems.
        flat = tf.reshape(tf.cast(values, tf.float32), [-1])
        return tf.train.FloatList(value=flat.numpy().tolist())

    features = tf.train.Features(
        feature={
            "input_features": tf.train.Feature(float_list=_float_list(x)),
            "label": tf.train.Feature(float_list=_float_list(y)),
        }
    )
    example = tf.train.Example(features=features)
    return example.SerializeToString()
def csv_dataset_to_tfrecords(base_filename, dataset,
                             n_shards, steps_per_shard,
                             compression_type=None):
    """Write a batched ``(x, y)`` dataset into ``n_shards`` TFRecord files.

    Shard ``i`` is written to ``"{base_filename}_{i:05d}-of-{n_shards:05d}"``
    and receives the next ``steps_per_shard`` batches of ``dataset``. Each
    example in a batch becomes one record produced by ``serialize_example``.
    If ``dataset`` runs out early, the remaining shards are written with
    fewer (possibly zero) records.

    Args:
        base_filename: Path prefix for the shard files.
        dataset: Iterable (e.g. a ``tf.data.Dataset``) that yields
            ``(x_batch, y_batch)`` pairs.
        n_shards: Number of shard files to write.
        steps_per_shard: Number of batches written to each shard.
        compression_type: Forwarded to ``tf.io.TFRecordOptions``
            (``None`` means no compression).

    Returns:
        list[str]: Paths of the written shard files, in shard order.
    """
    import itertools

    options = tf.io.TFRecordOptions(compression_type=compression_type)
    # A single iterator is shared by all shards, so each shard continues
    # where the previous one stopped. Calling `dataset.take(n)` once per
    # shard would restart iteration from the beginning every time, and every
    # shard would get the same leading batches.
    batches = iter(dataset)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in itertools.islice(batches, steps_per_shard):
                # Un-batch: one serialized tf.train.Example per row.
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames
错误出现在下面这段调用代码中：
n_shards = 20
# Hard-coded example counts for the train / valid / test splits (presumably
# the sizes of train_set / valid_set / test_set -- confirm), converted to a
# number of batches per shard. `batch_size` is defined elsewhere in the file;
# integer division drops any remainder, so a few trailing examples may be
# left out of the TFRecord files.
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords"
# makedirs(exist_ok=True) also creates missing parent directories and avoids
# the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(output_dir, exist_ok=True)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)
# Alias kept under the original misspelled name for any later code using it.
test_tfrecord_fielnames = test_tfrecord_filenames
恳请各位大佬指点，谢谢！