TensorFlow in Practice (2): Training and Prediction with TFRecord Data and the tf.data API

This post uses MNIST as an example to show how to train and evaluate a model with TFRecord-format data and the tf.data API.

References:

1. The TensorFlow cifar10 example

2. https://tensorflow.google.cn/guide/datasets

Contents

1. Parsing the data

2. Defining the model, loss, and training op

3. Training the model


1. Parsing the data

import tensorflow as tf  # TensorFlow 1.x; FLAGS and the mnist module come from the accompanying repo

def parse_data(example_proto):
    # The feature spec must match the keys used when the TFRecords were written.
    features = {'img_raw': tf.FixedLenFeature([], tf.string, ''),
                'label': tf.FixedLenFeature([], tf.int64, 0)}
    parsed_features = tf.parse_single_example(example_proto, features)
    # The image was stored as raw bytes; decode it and restore its shape.
    image = tf.decode_raw(parsed_features['img_raw'], tf.uint8)
    label = tf.cast(parsed_features['label'], tf.int64)
    image = tf.reshape(image, [FLAGS.image_height, FLAGS.image_width, 1])
    image = tf.cast(image, tf.float32)
    return image, label

This function is applied to every element of the tf.data.TFRecordDataset; we will use it below.
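For parse_data to work, the records must have been written with the same feature keys. Below is a minimal sketch of the writer side, assuming the images are uint8 numpy arrays; the helper name write_example is hypothetical and not part of the original code.

def write_example(writer, image_uint8, label):
    # Hypothetical helper, shown only to document the expected record layout.
    # image_uint8: numpy uint8 array of shape [image_height, image_width]
    example = tf.train.Example(features=tf.train.Features(feature={
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_uint8.tobytes()])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    }))
    writer.write(example.SerializeToString())

# usage:
# with tf.python_io.TFRecordWriter('train_img.tfrecords') as writer:
#     write_example(writer, image, label)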

2. Defining the model, loss, and training op

See the previous post in this series: TensorFlow in Practice (1): Training and Prediction with TFRecord Data and Queues.

3. Training the model

class _IteratorInitHook(tf.train.SessionRunHook):
    """Runs the iterator initializer right after the session is created.

    Hooks such as NanTensorHook attach extra fetches (here: loss) to every
    mon_sess.run() call, so initializing the iterator through a plain
    mon_sess.run() would try to pull a batch from a not-yet-initialized
    iterator. Initializing in after_create_session avoids that.
    """

    def __init__(self, init_op, feed_dict):
        self._init_op = init_op
        self._feed_dict = feed_dict

    def after_create_session(self, session, coord):
        session.run(self._init_op, feed_dict=self._feed_dict)


def train():
    # Feed the filenames at run time so the same graph can read
    # different TFRecord files.
    filenames = tf.placeholder(tf.string, [None])
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(mnist.parse_data)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(FLAGS.batch_size)
    dataset = dataset.repeat()

    iterator = dataset.make_initializable_iterator()
    init_hook = _IteratorInitHook(iterator.initializer,
                                  {filenames: ['train_img.tfrecords']})

    global_step = tf.train.get_or_create_global_step()
    images, labels = iterator.get_next()
    logits, pred = mnist.inference(images, training=True)
    loss = mnist.loss(logits, labels)
    train_op = mnist.train(loss, global_step)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[init_hook,
               tf.train.StopAtStepHook(last_step=FLAGS.max_step),
               tf.train.NanTensorHook(loss)],
        save_checkpoint_steps=100
    ) as mon_sess:
        while not mon_sess.should_stop():
            _, train_loss, train_step = mon_sess.run([train_op, loss, global_step])
            if train_step % 100 == 0:
                print('step: {}, loss: {}'.format(train_step, train_loss))

Here we use the make_initializable_iterator() interface to read batches; make_one_shot_iterator() works as well (a sketch follows below), and the official documentation linked at the top of this post covers both in detail. If anything is unclear, I strongly recommend consulting the official docs first; Google's documentation is very good. I used tf.train.MonitoredTrainingSession here, but a plain tf.Session also works. One caveat: hooks such as NanTensorHook attach their own fetches (here, loss) to every mon_sess.run() call, so the iterator must already be initialized before the first hooked run; that is what the small SessionRunHook above takes care of.
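For comparison, a minimal sketch of the make_one_shot_iterator() variant: no initializer run is needed, but the filenames must then be fixed in the graph rather than fed through a placeholder.

dataset = tf.data.TFRecordDataset(['train_img.tfrecords'])
dataset = dataset.map(mnist.parse_data).shuffle(50000).batch(FLAGS.batch_size).repeat()
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()  # usable directly, no initializer required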

One more issue: during training we usually want to evaluate on the validation set every fixed number of steps. A single make_initializable_iterator() bound to one dataset cannot switch between training and validation data, so here we use a reinitializable iterator instead:

def train_and_validation():
    training_dataset = tf.data.TFRecordDataset(['./train_img.tfrecords'])
    validation_dataset = tf.data.TFRecordDataset(['./validation_img.tfrecords'])
    test_dataset = tf.data.TFRecordDataset(['./test_img.tfrecords'])

    # Only the training pipeline shuffles and repeats; validation and test
    # are consumed once per evaluation pass.
    training_dataset = training_dataset.map(mnist.parse_data)
    training_dataset = training_dataset.shuffle(50000).batch(FLAGS.batch_size).repeat()
    validation_dataset = validation_dataset.map(mnist.parse_data).batch(FLAGS.batch_size)
    test_dataset = test_dataset.map(mnist.parse_data).batch(FLAGS.batch_size)

    # A reinitializable iterator: defined by structure only, with a separate
    # init op per dataset, so we can switch data sources in one session.
    iterator = tf.data.Iterator.from_structure(output_types=training_dataset.output_types,
                                               output_shapes=training_dataset.output_shapes)

    training_init_op = iterator.make_initializer(training_dataset)
    validation_init_op = iterator.make_initializer(validation_dataset)
    test_init_op = iterator.make_initializer(test_dataset)
    images, labels = iterator.get_next()

    # Fed at run time so layers like dropout/batch norm switch modes.
    training = tf.placeholder(dtype=tf.bool)
    logits, pred = mnist.inference(images, training=training)
    loss = mnist.loss(logits, labels)
    top_k_op = tf.nn.in_top_k(logits, labels, 1)
    global_step = tf.train.get_or_create_global_step()
    train_op = mnist.train(loss, global_step)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(training_init_op)
        print('begin to train!')
        ckpt = os.path.join(FLAGS.train_dir, 'model.ckpt')  # requires `import os`
        train_step = 0
        while train_step < FLAGS.max_step:
            _, train_loss, step = sess.run([train_op, loss, global_step],
                                           feed_dict={training: True})
            train_step += 1
            if train_step % 100 == 0:
                saver.save(sess, ckpt, train_step)
                if train_step % 1000 == 0:
                    # Training accuracy is measured on batches drawn from the
                    # still-active, repeating training iterator.
                    precision = evaluate(sess, top_k_op, training, mnist.TRAIN_EXAMPLES_NUM)
                    print('step: {}, loss: {}, training precision: {}'.format(train_step, train_loss, precision))
                # Point the iterator at the validation set, evaluate, then switch
                # back (which restarts and reshuffles the training data).
                sess.run(validation_init_op)
                precision = evaluate(sess, top_k_op, training, mnist.VALIDATION_EXAMPLES_NUM)
                print('step: {}, loss: {}, validation precision: {}'.format(train_step, train_loss, precision))
                sess.run(training_init_op)
        sess.run(test_init_op)
        precision = evaluate(sess, top_k_op, training, mnist.TEST_EXAMPLES_NUM)
        print('finally test precision: {}'.format(precision))
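The evaluate() helper called above is not shown in this post. Here is a minimal sketch consistent with its call sites, assuming it runs top_k_op over one pass of whichever dataset the iterator currently points at; the repo's actual implementation may differ.

import math
import numpy as np

def evaluate(sess, top_k_op, training, examples_num):
    # Count correct top-1 predictions over roughly examples_num examples.
    true_count, total = 0, 0
    num_iter = int(math.ceil(examples_num / float(FLAGS.batch_size)))
    for _ in range(num_iter):
        try:
            predictions = sess.run(top_k_op, feed_dict={training: False})
        except tf.errors.OutOfRangeError:
            break  # non-repeating datasets (validation/test) end here
        true_count += np.sum(predictions)
        total += predictions.shape[0]
    return true_count / float(total)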

The key point here is the reinitializable iterator created with tf.data.Iterator.from_structure(). Running the script produces:

begin to train!
step: 100, loss: 0.23804739117622375, validation precision: 0.9572
...
step: 1000, loss: 0.0894625335931778, training precision: 0.9841272727272727
step: 1000, loss: 0.0894625335931778, validation precision: 0.9816
...
step: 2000, loss: 0.009035548195242882, training precision: 0.9938
step: 2000, loss: 0.009035548195242882, validation precision: 0.9858
...
step: 3000, loss: 0.013184064999222755, training precision: 0.9929636363636364
step: 3000, loss: 0.013184064999222755, validation precision: 0.988
...
step: 4000, loss: 0.008312588557600975, training precision: 0.9950545454545454
step: 4000, loss: 0.008312588557600975, validation precision: 0.9876
step: 4100, loss: 0.03157630190253258, validation precision: 0.9868
step: 4200, loss: 0.06517153978347778, validation precision: 0.987
step: 4300, loss: 0.03605052828788757, validation precision: 0.9886
step: 4400, loss: 0.01293920911848545, validation precision: 0.991
step: 4500, loss: 0.0002804732066579163, validation precision: 0.991
step: 4600, loss: 0.0033769337460398674, validation precision: 0.9918
step: 4700, loss: 0.00031401298474520445, validation precision: 0.9914
step: 4800, loss: 0.0015675770118832588, validation precision: 0.992
step: 4900, loss: 0.02173098735511303, validation precision: 0.9924
step: 5000, loss: 0.002222904935479164, training precision: 0.9990181818181818
step: 5000, loss: 0.002222904935479164, validation precision: 0.9924
...
step: 6000, loss: 0.05456211417913437, training precision: 0.9997272727272727
step: 6000, loss: 0.05456211417913437, validation precision: 0.9922
...
step: 7000, loss: 0.0008142999722622335, training precision: 0.9999272727272728
step: 7000, loss: 0.0008142999722622335, validation precision: 0.993
...
step: 8000, loss: 0.00024679378839209676, training precision: 0.9999272727272728
step: 8000, loss: 0.00024679378839209676, validation precision: 0.9926
...
step: 9000, loss: 0.0018379478715360165, training precision: 1.0
step: 9000, loss: 0.0018379478715360165, validation precision: 0.9928
...
step: 10000, loss: 2.8087904411222553e-06, training precision: 1.0
step: 10000, loss: 2.8087904411222553e-06, validation precision: 0.993
finally test precision: 0.9917

Process finished with exit code 0

Training ran for 10,000 steps with a batch size of 128 and learning-rate decay: every 10 epochs the rate drops to 1/10 of its previous value, and 10 epochs correspond to 4,296 steps. In my experiments without decay, the validation accuracy stalls around 98.5% and barely moves after step 2,000. With decay, as the results above show, accuracy keeps improving past step 4,200 and reaches about 99.2%, so learning-rate decay clearly helps.
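For reference, a sketch of that decay schedule, assuming it lives inside mnist.train(); the flag name FLAGS.learning_rate and the optimizer choice are illustrative, not necessarily what the repo uses.

# Drop the learning rate to 1/10 every 10 epochs: with 55,000 training
# examples and batch size 128, that is 10 * 55000 // 128 = 4296 steps.
decay_steps = 10 * mnist.TRAIN_EXAMPLES_NUM // FLAGS.batch_size
lr = tf.train.exponential_decay(FLAGS.learning_rate, global_step,
                                decay_steps, 0.1, staircase=True)
train_op = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step)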

Finally, the complete code:

https://github.com/buptlj/learn_tf
