TensorFlow之TFRecord加载数据与模型训练

TFRecord加载数据与模型训练

参考代码:https://github.com/NLPLearn/QANet

1. 处理数据：将给定数据处理成模型输入所需的格式
# 1. Process the given raw data into the format expected as model input.
def process_file(file_name, data_type, training=True):
    """Load a JSON data file and turn it into a list of examples.

    Args:
        file_name: Path to a UTF-8 encoded JSON file.
        data_type: Label used only in the progress messages (e.g. "train").
        training: Not used by the visible code.

    Returns:
        A tuple ``(examples, eval_examples)``. ``examples`` is the list of
        examples, shuffled in place. ``eval_examples`` maps the string index
        of each example (in creation order) to the same example object.
    """
    print("Generating {} examples...".format(data_type))
    examples_count = 0
    examples = []
    eval_examples = {}
    with open(file_name, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
        # Placeholder from the article: build one example from ``data``.
        example = ...
        examples.append(example)
        eval_examples[str(examples_count)] = example
        # Advance the counter so each example gets a distinct eval key;
        # without this every example would overwrite eval_examples["0"]
        # once the elided per-example loop is filled in.
        examples_count += 1

    random.shuffle(examples)
    print("{} questions in total".format(len(examples)))
    return examples, eval_examples
2. 按照TFRecord的数据格式构建Feature，并写入TFRecord文件
# 2. Build features in the TFRecord format and write them to a record file.
def build_features(word2id, config, examples, data_type, out_file, is_test=False):
    """Serialize ``examples`` as ``tf.train.Example`` records into ``out_file``.

    Each record stores the int64 features ``input_ids``, ``input_mask``,
    ``segment_ids`` and ``id``; ``labels`` is added only when ``is_test``
    is False.

    Args:
        word2id: Vocabulary mapping (presumably consumed by the elided
            ``context_question_to_id`` call — confirm).
        config: Configuration object (presumably consumed by the elided
            helper calls — confirm).
        examples: Iterable of dict-like examples with at least an ``"id"``
            key and, unless ``is_test``, a ``"labels"`` key.
        data_type: Label used only in the progress message.
        out_file: Path of the TFRecord file to write.
        is_test: When True, the ``labels`` feature is not written.
    """

    def create_int_feature(values):
        """Wrap an iterable of ints as an int64 ``tf.train.Feature``."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

    print("Processing {} examples...".format(data_type))
    # The context manager closes (and flushes) the writer even if building a
    # feature raises part-way through, instead of leaking the file handle.
    with tf.python_io.TFRecordWriter(out_file) as writer:
        for example in examples:
            _, context_question_id = context_question_to_id(...)

            context_question_mask = [...]
            context_question_segment = [...]

            features = collections.OrderedDict()
            features['input_ids'] = create_int_feature(context_question_id)
            features['input_mask'] = create_int_feature(context_question_mask)
            features['segment_ids'] = create_int_feature(context_question_segment)
            # Use the shared helper for the scalar id too (wrapped in a list).
            features['id'] = create_int_feature([example["id"]])

            if not is_test:
                features['labels'] = create_int_feature(example['labels'])

            # Group the features of one sample into a single Example record
            # and append its serialized form to the TFRecord file.
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(tf_example.SerializeToString())
3. 读取TFRecord文件中的数据
# 3. Read data back from TFRecord files.
def get_record_parser(config, is_test=False):
    """Return a function that parses one serialized TFRecord example.

    Args:
        config: Object exposing ``context_limit`` and ``ques_limit``; the
            sequence features have fixed length
            ``context_limit + ques_limit + 3``.
        is_test: When True, the ``labels`` feature is not read (matching
            records written by ``build_features(..., is_test=True)``), and a
            zero int64 tensor of shape ``[3]`` is returned in its place so
            the output tuple keeps the same structure. Defaults to False,
            which preserves the original behavior.

    Returns:
        A callable mapping a scalar string tensor to the tuple
        ``(input_ids, input_mask, segment_ids, labels, qa_ids)`` of int64
        tensors.
    """

    def parse(example):
        # The extra 3 positions presumably hold special tokens around the
        # context/question pair — confirm against the feature builder.
        context_ques_limit = config.context_limit + config.ques_limit + 3

        feature_spec = {
            'input_ids': tf.FixedLenFeature(shape=[context_ques_limit], dtype=tf.int64),
            'input_mask': tf.FixedLenFeature(shape=[context_ques_limit], dtype=tf.int64),
            'segment_ids': tf.FixedLenFeature(shape=[context_ques_limit], dtype=tf.int64),
            'id': tf.FixedLenFeature(shape=[1], dtype=tf.int64),
        }
        if not is_test:
            # Required feature: parsing fails if a record lacks it.
            feature_spec['labels'] = tf.FixedLenFeature(shape=[3], dtype=tf.int64)

        parsed_example = tf.parse_single_example(example, feature_spec)

        # Test records carry no labels; emit zeros with the same shape/dtype
        # so downstream iterators see identical output types and shapes.
        if is_test:
            labels = tf.zeros([3], dtype=tf.int64)
        else:
            labels = parsed_example['labels']

        # FixedLenFeature(dtype=tf.int64) already yields int64 tensors, so no
        # casts are needed.
        return (parsed_example['input_ids'],
                parsed_example['input_mask'],
                parsed_example['segment_ids'],
                labels,
                parsed_example['id'])

    return parse

def get_batch_dataset(record_file, parser, config):
    """Build the training input pipeline from a TFRecord file.

    Records are parsed with ``parser`` using ``config.num_threads`` parallel
    calls, shuffled with a buffer of ``config.capacity`` elements, repeated
    indefinitely, and grouped into batches of ``config.batch_size``.
    """
    parallel_calls = tf.constant(config.num_threads, dtype=tf.int32)
    raw_records = tf.data.TFRecordDataset(record_file)
    parsed = raw_records.map(parser, num_parallel_calls=parallel_calls)
    # Shuffle before repeat so each pass over the data is reshuffled.
    endless = parsed.shuffle(config.capacity).repeat()
    return endless.batch(config.batch_size)

def get_dataset(record_file, parser, config):
    """Build an unshuffled evaluation pipeline from a TFRecord file.

    Records are parsed with ``parser`` using ``config.num_threads`` parallel
    calls, repeated indefinitely (no shuffling), and batched by
    ``config.batch_size``.
    """
    parallel_calls = tf.constant(config.num_threads, dtype=tf.int32)
    parsed = tf.data.TFRecordDataset(record_file).map(
        parser, num_parallel_calls=parallel_calls)
    endless = parsed.repeat()
    return endless.batch(config.batch_size)

# Count the total number of samples stored in a TFRecord file.
def total_sample(file_name):
    """Return how many records ``file_name`` contains.

    The whole file is read sequentially, so the cost is linear in file size.
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
4. 定义模型
# 4. Define the model.
class Model(object):
    """Skeleton TF1 model fed by a ``tf.data`` iterator.

    Construction pulls one batch from ``batch`` (an iterator whose
    ``get_next()`` yields ``(input_ids, input_mask, segment_ids, labels,
    qa_id)``), creates a non-trainable ``global_step`` variable, builds the
    graph and sets up the clipped-gradient Adam training op.

    ``trainable`` and ``is_test`` are accepted but not used by the visible
    code.
    """

    def __init__(self, config, batch, trainable=True, is_test=False):
        self.config = config
        # Dropout rate defaults to 0.1 unless a value is fed at run time.
        self.dropout = tf.placeholder_with_default(0.1, (), name="dropout")
        (self.input_ids, self.input_mask, self.segment_ids,
         self.labels, self.qa_id) = batch.get_next()

        self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
                                           initializer=tf.constant_initializer(0),
                                           trainable=False)
        # The graph-building method is named ``_build_graph``; calling the
        # nonexistent ``build_graph`` raised AttributeError.
        self._build_graph()
        self.train()

    def _build_graph(self):
        """Build the forward graph (left as a placeholder in the article)."""
        pass

    def train(self):
        """Create ``self.loss``, ``self.opt`` and the clipped ``self.train_op``.

        Gradients are computed only for trainable variables in the
        ``RNNEncoder`` scope, clipped by global norm ``config.grad_clip``, and
        applied with Adam. Applying them increments ``self.global_step``.
        """
        self.loss = ...
        if not hasattr(self, "lr"):
            # ``self.lr`` was never assigned; fall back to a config value
            # unless _build_graph set one. NOTE(review): assumes config
            # exposes ``learning_rate`` — confirm.
            self.lr = self.config.learning_rate
        self.opt = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.999, epsilon=1e-7)
        # Differentiate only with respect to the variables under the
        # "RNNEncoder" scope; compute_gradients raises if this list is empty.
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="RNNEncoder")

        grads = self.opt.compute_gradients(self.loss, var_list=var_list)
        gradients, variables = zip(*grads)
        # Use the instance's config: a bare ``config`` is undefined here and
        # raised NameError.
        capped_grads, _ = tf.clip_by_global_norm(gradients, self.config.grad_clip)
        self.train_op = self.opt.apply_gradients(zip(capped_grads, variables), global_step=self.global_step)
5. 数据加载与输入模型进行训练
# 5. Load the data and feed it to the model for training.
def train(config):
    """Train ``Model`` on TFRecord data, checkpointing periodically.

    A feedable (string-handle) iterator lets the same graph read either the
    training dataset or the dev dataset. If ``config.save_dir`` holds a
    checkpoint, training resumes from it. Every ``config.checkpoint`` steps
    the model is evaluated and saved as ``model_<step>.ckpt`` in
    ``config.save_dir``. Training stops once the global step reaches
    ``config.num_steps``.
    """
    parser = get_record_parser(config)

    train_dataset = get_batch_dataset(config.train_record_file, parser, config)
    dev_dataset = get_dataset(config.dev_record_file, parser, config)
    # The fed string handle selects which dataset the shared iterator reads.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)
    train_iterator = train_dataset.make_one_shot_iterator()
    dev_iterator = dev_dataset.make_one_shot_iterator()

    model = Model(config, iterator, trainable=True)
    # Instantiate the Saver; binding the class itself made restore/save fail.
    saver = tf.train.Saver()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True

    # The context manager releases the session's resources when done.
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.global_variables_initializer())

        train_handle = sess.run(train_iterator.string_handle())
        # Kept for the (elided) evaluate(...) call.
        dev_handle = sess.run(dev_iterator.string_handle())

        latest_ckpt = tf.train.latest_checkpoint(config.save_dir)
        if latest_ckpt:
            print("Loading last model from %s" % config.save_dir)
            saver.restore(sess, latest_ckpt)
        else:
            print("There is no saved checkpoint at %s. Creating model with fresh parameters." % config.save_dir)

        # Saver.save requires the parent directory to exist.
        os.makedirs(config.save_dir, exist_ok=True)

        # Run until the global step reaches num_steps. The original range()
        # loop overshot by one step after restoring a checkpoint.
        global_step = sess.run(model.global_step)
        while global_step < config.num_steps:
            loss, _ = sess.run([model.loss, model.train_op], feed_dict={
                handle: train_handle, model.dropout: config.dropout})
            # train_op increments global_step; read the post-update value.
            global_step = sess.run(model.global_step)

            if global_step % config.checkpoint == 0:
                metric = evaluate(...)
                # Save only at checkpoint intervals, not on every step.
                filename = os.path.join(config.save_dir, "model_{}.ckpt".format(global_step))
                saver.save(sess, filename)
  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值