TFRecord Data Loading and Model Training
Reference code: https://github.com/NLPLearn/QANet
1. Process the data into the model's input format
# 1. Process the raw data into the model's input format.
import json
import random

def process_file(file_name, data_type, training=True):
    print("Generating {} examples...".format(data_type))
    examples_count = 0
    examples = []
    eval_examples = {}
    with open(file_name, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
        # Build one example per question; the construction is elided here
        # (see the sketch after this function).
        example = ...
        examples.append(example)
        eval_examples[str(examples_count)] = example
        examples_count += 1
    random.shuffle(examples)
    print("{} questions in total".format(len(examples)))
    return examples, eval_examples
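The example construction above is elided; purely as an illustration, a SQuAD-style loop might look like the sketch below (the field names data/paragraphs/qas/answers are assumptions about the JSON layout, not the original format).

# Hedged sketch: the JSON field names below are assumptions.
for article in data["data"]:
    for para in article["paragraphs"]:
        for qa in para["qas"]:
            example = {"id": examples_count,
                       "context": para["context"],
                       "question": qa["question"],
                       "answers": [a["text"] for a in qa["answers"]]}
            examples.append(example)
            eval_examples[str(examples_count)] = example
            examples_count += 1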
2. Build tf.train.Feature structures and write them to a TFRecord file
# 2. Build features in the tf.train.Feature format and write them to a TFRecord file.
import collections
import tensorflow as tf

def build_features(word2id, config, examples, data_type, out_file, is_test=False):
    def create_int_feature(values):
        feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
        return feature

    print("Processing {} examples...".format(data_type))
    writer = tf.python_io.TFRecordWriter(out_file)
    for example in examples:
        _, context_question_id = context_question_to_id(...)
        context_question_mask = [...]
        context_question_segment = [...]
        features = collections.OrderedDict()
        features['input_ids'] = create_int_feature(context_question_id)
        features['input_mask'] = create_int_feature(context_question_mask)
        features['segment_ids'] = create_int_feature(context_question_segment)
        features['id'] = create_int_feature([example["id"]])
        if not is_test:
            features['labels'] = create_int_feature(example['labels'])
        # Wrap the per-sample features into a single tf.train.Example.
        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        # Serialize each example and write it to the TFRecord file.
        writer.write(tf_example.SerializeToString())
    writer.close()
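Before building the input pipeline, it is worth reading one serialized record back to confirm the schema; a minimal check, assuming the records were written to "train.tfrecord" (the path is an assumption):

import tensorflow as tf

# Parse the first serialized record back into a tf.train.Example and list its keys.
record_iter = tf.python_io.tf_record_iterator("train.tfrecord")  # assumed path
first_example = tf.train.Example.FromString(next(record_iter))
print(sorted(first_example.features.feature.keys()))
# Expect: ['id', 'input_ids', 'input_mask', 'labels', 'segment_ids']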
3. Read the data back from the TFRecord file
# 3. Parse serialized examples read from the TFRecord file.
def get_record_parser(config):
    def parse(example):
        # The +3 presumably accounts for special tokens (e.g. [CLS]/[SEP]).
        context_ques_limit = config.context_limit + config.ques_limit + 3
        dicts = {
            'input_ids': tf.FixedLenFeature(shape=[context_ques_limit], dtype=tf.int64),
            'input_mask': tf.FixedLenFeature(shape=[context_ques_limit], dtype=tf.int64),
            'segment_ids': tf.FixedLenFeature(shape=[context_ques_limit], dtype=tf.int64),
            # Note: test records are written without 'labels' (is_test=True above),
            # so this key must be dropped when parsing test files.
            'labels': tf.FixedLenFeature(shape=[3], dtype=tf.int64),
            'id': tf.FixedLenFeature(shape=[1], dtype=tf.int64)
        }
        # The parsed tensors are already int64, so no extra casts are needed.
        parsed_example = tf.parse_single_example(example, dicts)
        input_ids = parsed_example['input_ids']
        input_mask = parsed_example['input_mask']
        segment_ids = parsed_example['segment_ids']
        labels = parsed_example['labels']
        qa_ids = parsed_example['id']
        return input_ids, input_mask, segment_ids, labels, qa_ids
    return parse
def get_batch_dataset(record_file, parser, config):
    num_threads = tf.constant(config.num_threads, dtype=tf.int32)
    # Training pipeline: shuffle within a buffer, then repeat indefinitely.
    dataset = tf.data.TFRecordDataset(record_file).map(
        parser, num_parallel_calls=num_threads).shuffle(config.capacity).repeat()
    dataset = dataset.batch(config.batch_size)
    return dataset

def get_dataset(record_file, parser, config):
    num_threads = tf.constant(config.num_threads, dtype=tf.int32)
    # Evaluation pipeline: no shuffling, just repeat and batch.
    dataset = tf.data.TFRecordDataset(record_file).map(
        parser, num_parallel_calls=num_threads).repeat().batch(config.batch_size)
    return dataset

# Count the total number of samples in a TFRecord file.
def total_sample(file_name):
    sample_nums = 0
    for record in tf.python_io.tf_record_iterator(file_name):
        sample_nums += 1
    return sample_nums
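A quick way to confirm that parsing and batching work end to end is to pull a single batch through a throwaway session; a sketch, assuming a config object and the record file from step 2:

# Sanity check: build the pipeline and fetch one batch via a session.
parser = get_record_parser(config)
dataset = get_batch_dataset("train.tfrecord", parser, config)  # assumed path
batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    input_ids, input_mask, segment_ids, labels, qa_ids = sess.run(batch)
    print(input_ids.shape)  # (batch_size, context_limit + ques_limit + 3)
print("total samples:", total_sample("train.tfrecord"))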
4. Define the model
# 4. Define the Model.
class Model(object):
    def __init__(self, config, batch, trainable=True, is_test=False):
        self.config = config
        self.dropout = tf.placeholder_with_default(0.1, (), name="dropout")
        self.input_ids, self.input_mask, self.segment_ids, self.labels, self.qa_id = batch.get_next()
        self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
                                           initializer=tf.constant_initializer(0),
                                           trainable=False)
        # Assumes the config provides a learning rate; train() reads self.lr below.
        self.lr = config.learning_rate
        self.build_graph()
        self.train()

    def build_graph(self):
        pass

    def train(self):
        self.loss = ...
        self.opt = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.999, epsilon=1e-7)
        # Gradients can be computed with respect to a chosen subset of variables.
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="RNNEncoder")
        grads = self.opt.compute_gradients(self.loss, var_list=var_list)
        gradients, variables = zip(*grads)
        capped_grads, _ = tf.clip_by_global_norm(gradients, self.config.grad_clip)
        self.train_op = self.opt.apply_gradients(zip(capped_grads, variables), global_step=self.global_step)
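The var_list trick above restricts optimization to a single variable scope. A small self-contained illustration (the scope and variable names here are made up for the demo):

import tensorflow as tf

# Two scopes, but gradients are requested only for the "RNNEncoder" variables.
with tf.variable_scope("RNNEncoder"):
    w1 = tf.get_variable("w1", shape=[4, 4])
with tf.variable_scope("Embedding"):
    w2 = tf.get_variable("w2", shape=[4, 4])
loss = tf.reduce_sum(tf.matmul(w1, w2))
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="RNNEncoder")
grads = tf.train.AdamOptimizer(0.001).compute_gradients(loss, var_list=var_list)
print([v.name for _, v in grads])  # only 'RNNEncoder/w1:0'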
5. Load the data and feed it to the model for training
# 5. Load the data and feed it into the model for training.
import os

def train(config):
    parser = get_record_parser(config)
    train_dataset = get_batch_dataset(config.train_record_file, parser, config)
    dev_dataset = get_dataset(config.dev_record_file, parser, config)
    # A feedable handle lets the same graph switch between the train and dev iterators.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)
    train_iterator = train_dataset.make_one_shot_iterator()
    dev_iterator = dev_dataset.make_one_shot_iterator()
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    model = Model(config, iterator, trainable=True)
    sess.run(tf.global_variables_initializer())
    train_vars = tf.trainable_variables()
    train_handle = sess.run(train_iterator.string_handle())
    dev_handle = sess.run(dev_iterator.string_handle())
    saver = tf.train.Saver()
    if os.path.exists(os.path.join(config.save_dir, "checkpoint")):
        print("Loading last model from %s" % config.save_dir)
        saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
    else:
        print("There is no saved checkpoint at %s. Creating model with fresh parameters." % config.save_dir)
    global_step = max(sess.run(model.global_step), 1)
    for _ in range(global_step, config.num_steps + 1):
        global_step = sess.run(model.global_step) + 1
        loss, train_op = sess.run([model.loss, model.train_op], feed_dict={
            handle: train_handle, model.dropout: config.dropout})
        if global_step % config.checkpoint == 0:
            metric = evaluate(...)
            filename = os.path.join(config.save_dir, "model_{}.ckpt".format(global_step))
            saver.save(sess, filename)
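evaluate(...) is elided above; the point of the feedable handle is that the dev stream runs through the same graph simply by swapping handles. A hedged sketch (the helper's name and signature are assumptions):

def evaluate(model, sess, handle, dev_handle, num_batches):
    # Hypothetical helper: average the loss over a few dev batches by feeding
    # the dev iterator's handle into the same graph, with dropout disabled.
    losses = []
    for _ in range(num_batches):
        loss = sess.run(model.loss, feed_dict={handle: dev_handle, model.dropout: 0.0})
        losses.append(loss)
    return sum(losses) / max(len(losses), 1)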