tensorflow dataset tfrecord的写入和读取

tensorflow的版本:1.13.1

针对 TensorFlow 2 及以上的版本,可以参考 TensorFlow 的官方介绍:https://www.tensorflow.org/tutorials/load_data/tfrecord

    def create_int_feature(self, values):
        """Pack an iterable of integers into an int64 ``tf.train.Feature``.

        Args:
            values: Iterable of integer values; it is materialised with
                ``list(values)`` before being stored.

        Returns:
            A ``tf.train.Feature`` whose ``int64_list`` holds the values.
        """
        int64_list = tf.train.Int64List(value=list(values))
        return tf.train.Feature(int64_list=int64_list)
    """
    以下三个函数可以参考tensorflow的官方文档:https://www.tensorflow.org/tutorials/load_data/tfrecord
    """
    @staticmethod
    def _bytes_feature(value):
        """Wrap a single value in a bytes-list ``tf.train.Feature``.

        Declared as a static method because the function takes no ``self``:
        without the decorator, calling it through an instance would bind the
        instance itself to ``value``.

        Args:
            value: The value to store. When it has the same type as
                ``tf.constant(0)``, its ``.numpy()`` result is stored instead.

        Returns:
            A ``tf.train.Feature`` whose ``bytes_list`` holds the one value.
        """
        if isinstance(value, type(tf.constant(0))):
            # BytesList won't unpack a string from an EagerTensor.
            # NOTE(review): ``.numpy()`` presumably requires eager execution;
            # confirm this branch is usable under TF 1.13 graph mode.
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _float_feature(value):
        """Wrap a single float / double in a float-list ``tf.train.Feature``.

        Declared as a static method because the function takes no ``self``:
        without the decorator, calling it through an instance would bind the
        instance itself to ``value``.

        Args:
            value: The single numeric value to store.

        Returns:
            A ``tf.train.Feature`` whose ``float_list`` holds the one value.
        """
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    @staticmethod
    def _int64_feature(value):
        """Wrap a single bool / enum / int / uint in an int64-list ``tf.train.Feature``.

        Declared as a static method because the function takes no ``self``:
        without the decorator, calling it through an instance would bind the
        instance itself to ``value``.

        Args:
            value: The single integer-like value to store.

        Returns:
            A ``tf.train.Feature`` whose ``int64_list`` holds the one value.
        """
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


    def convert_data_to_tfrecord(self, examples: List[NerExample] = None, tfrecord_path: str = None):
        """Serialise examples into a TFRecord file at ``tfrecord_path``.

        Anything already present at ``tfrecord_path`` is removed first. Each
        example is written as one ``tf.train.Example`` with two int64 list
        features, ``"text_ids"`` (from ``example.word_ids``) and
        ``"label_ids"`` (from ``example.label_ids``).

        Args:
            examples: Examples to write; each must expose ``word_ids`` and
                ``label_ids``.
            tfrecord_path: Destination path of the TFRecord file.

        Raises:
            TypeError: If ``examples`` or ``tfrecord_path`` is ``None``.
        """
        # Fail before touching the filesystem so that an existing record file
        # is not deleted and then left missing because of a bad argument.
        if examples is None or tfrecord_path is None:
            raise TypeError("examples and tfrecord_path must both be provided")

        # A TFRecord output is normally a regular file, which shutil.rmtree
        # cannot delete; rmtree is kept only for the directory case.
        if os.path.isdir(tfrecord_path):
            shutil.rmtree(tfrecord_path)
        elif os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)

        # The context manager closes the writer even if serialisation fails.
        with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
            for example in examples:
                features = collections.OrderedDict()
                features["text_ids"] = self.create_int_feature(example.word_ids)
                features["label_ids"] = self.create_int_feature(example.label_ids)
                tf_example = tf.train.Example(features=tf.train.Features(feature=features))
                writer.write(tf_example.SerializeToString())




    def get_train_tfrecord(self, tfrecord_path=None, num_epochs=1, batch_size=16, shuffle=True):
        """Build a padded-batch input pipeline over a TFRecord file.

        Records are parsed into ``(text_ids, label_ids)`` int32 vectors,
        optionally shuffled, padded with zeros to the longest sequence in each
        batch, and repeated for ``num_epochs`` epochs.

        Args:
            tfrecord_path: Path (or paths) passed to ``tf.data.TFRecordDataset``.
            num_epochs: Number of times the batched dataset is repeated.
            batch_size: Number of examples per batch.
            shuffle: Whether to shuffle examples with a 1000-element buffer.

        Returns:
            The ``get_next()`` result of a one-shot iterator: a tuple
            ``(batch_text_ids, batch_label_ids)``.
        """
        # Variable-length features are used because sequences are padded per
        # batch below. NOTE(review): if every record is written with a fixed
        # length, tf.FixedLenFeature could be used instead; confirm with the
        # writer side.
        feature_spec = {
            "text_ids": tf.VarLenFeature(tf.int64),
            "label_ids": tf.VarLenFeature(tf.int64),
        }

        def parse_one_example(serialized):
            parsed = tf.parse_single_example(serialized, feature_spec)
            # VarLenFeature yields SparseTensors, which padded_batch cannot
            # pad; convert them to dense 1-D tensors before batching.
            text_ids = tf.cast(tf.sparse_tensor_to_dense(parsed["text_ids"]), tf.int32)
            label_ids = tf.cast(tf.sparse_tensor_to_dense(parsed["label_ids"]), tf.int32)
            return text_ids, label_ids

        dataset = tf.data.TFRecordDataset(tfrecord_path)
        dataset = dataset.map(parse_one_example)

        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000, reshuffle_each_iteration=True)

        # padding_values mirrors the (text_ids, label_ids) structure: one
        # int32 zero per component.
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes=([None], [None]),
            padding_values=(0, 0),
        )
        dataset = dataset.repeat(num_epochs)
        iterator = dataset.make_one_shot_iterator()
        # batch_text_ids = batch_data[0], batch_label_ids = batch_data[1]
        return iterator.get_next()

    def get_train_iterator(self, tfrecord_path=None, num_epochs=1, batch_size=16, shuffle=True):
        """Build a batched input pipeline that yields a feature dictionary.

        Raw records are optionally shuffled and repeated for ``num_epochs``
        epochs, then parsed and batched in one fused ``map_and_batch`` step.

        Args:
            tfrecord_path: Path (or paths) passed to ``tf.data.TFRecordDataset``.
            num_epochs: Number of times the record stream is repeated.
            batch_size: Number of examples per batch; the final smaller batch
                is kept (``drop_remainder=False``).
            shuffle: Whether to shuffle records with a 1000-element buffer.

        Returns:
            The ``get_next()`` result of a one-shot iterator: a dict with the
            keys ``"text_ids"`` and ``"label_ids"``, each cast to int32.
        """
        feature_spec = {
            "text_ids": tf.VarLenFeature(tf.int64),
            "label_ids": tf.VarLenFeature(tf.int64),
        }

        def parse_one_example(serialized):
            # Parse one serialised Example and narrow both features to int32.
            parsed = tf.parse_single_example(serialized, feature_spec)
            for key in ("text_ids", "label_ids"):
                parsed[key] = tf.to_int32(parsed[key])
            return parsed

        records = tf.data.TFRecordDataset(tfrecord_path)
        if shuffle:
            records = records.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
        records = records.repeat(num_epochs)

        fused_parse_and_batch = tf.data.experimental.map_and_batch(
            parse_one_example,
            batch_size=batch_size,
            drop_remainder=False,
        )
        batched = records.apply(fused_parse_and_batch)

        return batched.make_one_shot_iterator().get_next()

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值