# Read TFRecords data and store it into a pkl (pickle) file.

import tensorflow as tf
import os
import pickle
import collections

# Pin TensorFlow to GPU 1 (assumes a multi-GPU host — TODO confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
train_precomputed_file = "bert_joint/bert-joint-baseline/nq-train.tfrecords-00000-of-00001"
# train_precomputed_file = "bert_model_output/eval.tf_record"

# save_pkl_path = "simplified-nq-pkl/test.pkl"
save_pkl_path = "simplified-nq-pkl/nq-train-tfrecords.pkl"  # destination pickle file
seq_length = 512  # fixed token sequence length per example (was 384)
is_training = True  # when True, also parse the label features (start/end/answer_types)
drop_remainder = False  # NOTE(review): defined but unused — batch() below hardcodes False
batch_size = 1000  # number of examples fetched per sess.run() during conversion


def _parse_record(example_photo):
    """Decode one serialized tf.Example into fixed-length int64 tensors.

    Args:
        example_photo: scalar string tensor holding a single serialized
            tf.Example record.

    Returns:
        A dict mapping feature names to dense tensors. The label features
        (start/end positions and answer type) are parsed only when the
        module-level ``is_training`` flag is set.
    """
    feature_spec = {
        "unique_ids": tf.io.FixedLenFeature([], tf.int64),
        "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    }

    if is_training:
        # Supervised labels exist only in training records.
        for label_key in ("start_positions", "end_positions", "answer_types"):
            feature_spec[label_key] = tf.io.FixedLenFeature([], tf.int64)

    return tf.io.parse_single_example(example_photo, features=feature_spec)


def save_pkl(save_pkl_path, features):
    """Serialize *features* to disk as a pickle file and report the count.

    Args:
        save_pkl_path: destination path for the pickle file.
        features: sequence of per-example feature dicts to persist.
    """
    with open(save_pkl_path, "wb") as out_file:
        # Persist permanently in pickle format, then announce how many
        # records were written.
        pickle.dump(features, out_file)
        print("保存成功!{}".format(len(features)))


def read_test(input_file):
    """Iterate a TFRecord file batch-by-batch and pickle every parsed example.

    Builds a TF1 graph-mode tf.data pipeline over *input_file*, pulls batches
    of parsed features through a Session, re-packs each example into an
    OrderedDict of plain numpy values, accumulates all of them in memory, and
    finally writes the whole list to the module-level ``save_pkl_path``.

    Args:
        input_file: path to the TFRecord file to convert.

    Side effects: writes one pickle file at ``save_pkl_path``; prints the
    batch counter and final record count as it runs.
    """
    # Read the TFRecords file via the tf.data Dataset API.
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=False)  # read in batches
    iterator = dataset.make_one_shot_iterator()   # TF1-style one-shot iterator
    with tf.Session() as sess:
        count = 0
        features_plk_list_all = []
        try:
            # get_next() is created ONCE outside the loop so the graph does
            # not grow per iteration; each sess.run() below fetches a fresh
            # batch until the iterator raises OutOfRangeError.
            featurehh = iterator.get_next()
            while True:
                features = sess.run(featurehh)
                # print(type(features))  # (batch_size,)
                unique_ids = features['unique_ids']  # (batch_size, seq_max)
                input_ids = features["input_ids"]
                input_mask = features["input_mask"]
                segment_ids = features["segment_ids"]
                if is_training:
                    start_positions = features["start_positions"]
                    end_positions = features["end_positions"]
                    answer_types = features["answer_types"]
                features_plk_list = []
                len_size = len(unique_ids)
                # Re-pack the batched arrays into one OrderedDict per example;
                # scalar features are wrapped in single-element lists.
                for i in range(len_size):
                    item = collections.OrderedDict()
                    item["unique_ids"] = [unique_ids[i]]
                    item["input_ids"] = input_ids[i]
                    item["input_mask"] = input_mask[i]
                    item["segment_ids"] = segment_ids[i]
                    if is_training:
                        item["start_positions"] = [start_positions[i]]
                        item["end_positions"] = [end_positions[i]]
                        item["answer_types"] = [answer_types[i]]
                    features_plk_list.append(item)
                assert len(features_plk_list) == len_size, (len(features_plk_list), len_size)
                # save_pkl(save_pkl_path, features_plk_list)
                print(count)
                features_plk_list_all.extend(features_plk_list)
                count = count + 1
                # if count >= 100:
                #     break
        except tf.errors.OutOfRangeError:
            # OutOfRangeError is TF1's normal end-of-dataset signal, not a
            # real failure — we fall through to the final save below.
            # print(features)
            print("异常")
            pass
        save_pkl(save_pkl_path, features_plk_list_all)
        print("数据读取完毕: 数量:{}".format(len(features_plk_list_all)))

# Guard the conversion behind __main__ so importing this module does not
# immediately kick off a full (slow, file-writing) dataset conversion.
if __name__ == "__main__":
    read_test(train_precomputed_file)

 
