import tensorflow as tf
import os
import pickle
import collections
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # restrict TensorFlow to GPU 1
# Input: precomputed NQ training features serialized as TFRecords.
train_precomputed_file = "bert_joint/bert-joint-baseline/nq-train.tfrecords-00000-of-00001"
# train_precomputed_file = "bert_model_output/eval.tf_record"
# save_pkl_path = "simplified-nq-pkl/test.pkl"
# Output: where the converted feature list is pickled.
save_pkl_path = "simplified-nq-pkl/nq-train-tfrecords.pkl"
seq_length = 512 # 384  (alternative sequence length used for the eval records)
is_training = True  # when True, also parse the label fields (start/end/answer_types)
drop_remainder = False  # keep the final, possibly smaller, batch
batch_size = 1000 # batch size used while reading/converting the records
def _parse_record(example_photo):
    """Deserialize one serialized tf.Example into a dict of fixed-length int64 tensors.

    Uses the module-level ``seq_length`` for the per-token features and, when
    ``is_training`` is set, additionally parses the scalar label fields.
    """
    schema = {
        "unique_ids": tf.io.FixedLenFeature([], tf.int64),
        "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    }
    if is_training:
        # Supervised targets are scalars, present only in training records.
        for label_key in ("start_positions", "end_positions", "answer_types"):
            schema[label_key] = tf.io.FixedLenFeature([], tf.int64)
    return tf.io.parse_single_example(example_photo, features=schema)
def save_pkl(save_pkl_path, features):
    """Persist ``features`` to ``save_pkl_path`` as a pickle and report the count.

    Args:
        save_pkl_path: destination file path for the pickle.
        features: sized collection of feature records to serialize.
    """
    out_file = open(save_pkl_path, "wb")
    try:
        pickle.dump(features, out_file)
    finally:
        out_file.close()
    print("保存成功!{}".format(len(features)))
def read_test(input_file):
    """Stream every example out of a TFRecord file and persist them as one pickle.

    Builds a TF1-style one-shot iterator over ``input_file``, decodes each
    batched example via ``_parse_record``, converts every example into an
    ``OrderedDict`` of numpy values, and finally writes the accumulated list
    to the module-level ``save_pkl_path`` via ``save_pkl``.

    Args:
        input_file: path (or list of paths) to the TFRecord file(s) to read.
    """
    dataset = tf.data.TFRecordDataset(input_file)
    dataset = dataset.map(_parse_record)
    # Use the module-level flag rather than a hard-coded False so the
    # configuration at the top of the file actually takes effect.
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
    iterator = dataset.make_one_shot_iterator()  # TF1 graph-mode iteration
    with tf.Session() as sess:
        count = 0
        features_plk_list_all = []
        try:
            # Build the fetch op once; each sess.run() advances the iterator.
            featurehh = iterator.get_next()
            while True:
                features = sess.run(featurehh)
                unique_ids = features["unique_ids"]      # shape (batch,) — presumably; verify against record schema
                input_ids = features["input_ids"]        # shape (batch, seq_length)
                input_mask = features["input_mask"]
                segment_ids = features["segment_ids"]
                if is_training:
                    start_positions = features["start_positions"]
                    end_positions = features["end_positions"]
                    answer_types = features["answer_types"]
                features_plk_list = []
                len_size = len(unique_ids)
                for i in range(len_size):
                    # One OrderedDict per example; scalars are wrapped in
                    # single-element lists to match the downstream format.
                    item = collections.OrderedDict()
                    item["unique_ids"] = [unique_ids[i]]
                    item["input_ids"] = input_ids[i]
                    item["input_mask"] = input_mask[i]
                    item["segment_ids"] = segment_ids[i]
                    if is_training:
                        item["start_positions"] = [start_positions[i]]
                        item["end_positions"] = [end_positions[i]]
                        item["answer_types"] = [answer_types[i]]
                    features_plk_list.append(item)
                assert len(features_plk_list) == len_size, (len(features_plk_list), len_size)
                print(count)
                features_plk_list_all.extend(features_plk_list)
                count = count + 1
        except tf.errors.OutOfRangeError:
            # Raised when the iterator is exhausted — this is the normal
            # termination path for a one-shot iterator, not an error.
            print("异常")
            pass
        save_pkl(save_pkl_path, features_plk_list_all)
        print("数据读取完毕: 数量:{}".format(len(features_plk_list_all)))
# Script entry point: convert the configured TFRecord file into a pickle.
read_test(train_precomputed_file)