# TensorFlow version: 1.13.1
# For TensorFlow 2.x and later, see the official guide:
# https://www.tensorflow.org/tutorials/load_data/tfrecord
def create_int_feature(self, values):
    """Wrap an iterable of ints into a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=list(values))
    return tf.train.Feature(int64_list=int64_list)
"""
以下三个函数可以参考tensorflow的官方文档:https://www.tensorflow.org/tutorials/load_data/tfrecord
"""
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    # BytesList won't unpack a string from an EagerTensor, so convert
    # eager tensors to their numpy value first.
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Returns a float_list from a float / double."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def convert_data_to_tfrecord(self, examples: List[NerExample] = None, tfrecord_path: str = None):
    """Serialize NER examples into a single TFRecord file.

    Args:
        examples: parsed examples, each exposing `word_ids` and `label_ids`
            (sequences of ints).
        tfrecord_path: destination file path; an existing file is replaced.
    """
    # tfrecord_path is a regular file, so delete it with os.remove.
    # The original called shutil.rmtree, which only works on directories
    # and would raise NotADirectoryError here.
    if os.path.exists(tfrecord_path):
        os.remove(tfrecord_path)
    writer = tf.python_io.TFRecordWriter(tfrecord_path)
    try:
        for example in examples:
            features = collections.OrderedDict()
            features["text_ids"] = self.create_int_feature(example.word_ids)
            features["label_ids"] = self.create_int_feature(example.label_ids)
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            # Fixed typo: the protobuf method is SerializeToString;
            # the original `SerizlizeToString` raised AttributeError.
            writer.write(tf_example.SerializeToString())
    finally:
        # Close the writer even if serialization fails part-way.
        writer.close()
def get_train_tfrecord(self, tfrecord_path=None, num_epochs=1, batch_size=16, shuffle=True):
    """Build a one-shot, padded-batch input pipeline over a TFRecord file.

    Args:
        tfrecord_path: path to the TFRecord file written by
            `convert_data_to_tfrecord`.
        num_epochs: number of passes over the data.
        batch_size: examples per batch.
        shuffle: whether to shuffle examples each epoch.

    Returns:
        A (batch_text_ids, batch_label_ids) pair of int32 tensors, each
        padded with 0 to the longest sequence in the batch.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    feature_dict = {"text_ids": tf.VarLenFeature(tf.int64),
                    "label_ids": tf.VarLenFeature(tf.int64)}
    # Alternative: FixedLenFeature([self.max_len], tf.int64) if all
    # sequences were padded to max_len at write time.

    def parse_one_example(serialized):
        parsed = tf.parse_single_example(serialized, feature_dict)
        # VarLenFeature yields SparseTensors, but padded_batch only
        # accepts dense components in TF 1.x — densify before casting.
        text_ids = tf.to_int32(tf.sparse.to_dense(parsed["text_ids"]))
        label_ids = tf.to_int32(tf.sparse.to_dense(parsed["label_ids"]))
        return text_ids, label_ids

    dataset = dataset.map(parse_one_example)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
    # padding_values must mirror the (text, label) tuple structure in
    # TF 1.13; a lone scalar is not broadcast across components there.
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=([None], [None]),
                                   padding_values=(0, 0))
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    batch_data = iterator.get_next()
    return batch_data  # batch_text_ids = batch_data[0], batch_label_ids = batch_data[1]
def get_train_iterator(self, tfrecord_path=None, num_epochs=1, batch_size=16, shuffle=True):
    """Build a one-shot input pipeline over a TFRecord file, returning the
    next-batch op as a dict with keys "text_ids" and "label_ids".

    Unlike get_train_tfrecord, this variant batches via
    tf.data.experimental.map_and_batch and keeps the parsed features as
    SparseTensors (VarLenFeature output), which map_and_batch can batch
    directly.

    Args:
        tfrecord_path: path to the TFRecord file.
        num_epochs: number of passes over the data.
        batch_size: examples per batch; the final partial batch is kept
            (drop_remainder=False).
        shuffle: whether to shuffle raw records each epoch.

    Returns:
        A dict of int32 SparseTensors keyed by "text_ids" / "label_ids"
        produced by iterator.get_next().
    """
    feature_dict = {"text_ids": tf.VarLenFeature(tf.int64),
                    "label_ids": tf.VarLenFeature(tf.int64)}
    def parse_one_example(example):
        example = tf.parse_single_example(example, feature_dict)
        # tf.to_int32 (tf.cast) accepts SparseTensors, so the parsed
        # sparse features are cast in place without densifying.
        text_ids = tf.to_int32(example["text_ids"])
        label_ids = tf.to_int32(example["label_ids"])
        example["text_ids"] = text_ids
        example["label_ids"] = label_ids
        return example
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    if shuffle:
        # Shuffles the serialized records before parsing (cheaper than
        # shuffling parsed tensors).
        dataset = dataset.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
    dataset = dataset.repeat(num_epochs)
    # map_and_batch fuses parsing and batching for throughput.
    # NOTE(review): downstream code presumably converts the batched
    # SparseTensors to dense before model input — confirm with callers.
    dataset = dataset.apply(tf.data.experimental.map_and_batch(lambda record: parse_one_example(record), batch_size=batch_size, drop_remainder=False))
    iterator = dataset.make_one_shot_iterator()
    batch_data = iterator.get_next()
    return batch_data