A reference for adding BN (and L2) with slim

For details, see this blog post: https://blog.csdn.net/jiruiYang/article/details/77202674
It is explained well, and this GitHub code is worth borrowing from: https://github.com/soloice/mnist-bn
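
The full listing below only covers the BN part of the title. For the L2 part, here is a minimal sketch of how L2 weight decay is usually attached with slim; the layer shapes and the 0.0005 scale are illustrative assumptions, not taken from the original code:

import tensorflow as tf
import tensorflow.contrib.slim as slim

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
x_image = tf.reshape(x, [-1, 28, 28, 1])

# Attach an L2 regularizer to every conv/fc weight via arg_scope (0.0005 is an assumed scale)
with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    weights_regularizer=slim.l2_regularizer(0.0005)):
    net = slim.conv2d(x_image, 16, [5, 5], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.flatten(net)
    logits = slim.fully_connected(net, 10, activation_fn=None, scope='logits')

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
# slim drops the per-layer L2 terms into the REGULARIZATION_LOSSES collection;
# add them to the data loss to get the loss that is actually optimized
reg_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
total_loss = cross_entropy + reg_loss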

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops

FLAGS = None


def model():
    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    keep_prob = tf.placeholder(tf.float32, [])
    y_ = tf.placeholder(tf.float32, [None, 10])
    is_training = tf.placeholder(tf.bool, [])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    # Wire batch norm into every conv/fc layer via normalizer_fn; is_training is fed at run time
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.crelu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params={'is_training': is_training, 'decay': 0.95}):
        conv1 = slim.conv2d(x_image, 16, [5, 5], scope='conv1')
        pool1 = slim.max_pool2d(conv1, [2, 2], scope='pool1')
        conv2 = slim.conv2d(pool1, 32, [5, 5], scope='conv2')
        pool2 = slim.max_pool2d(conv2, [2, 2], scope='pool2')
        flatten = slim.flatten(pool2)
        fc = slim.fully_connected(flatten, 1024, scope='fc1')
        drop = slim.dropout(fc, keep_prob=keep_prob)
        logits = slim.fully_connected(drop, 10, activation_fn=None, scope='logits')

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

    step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)

    # Collect the BN moving_mean/moving_variance update ops and attach them to
    # cross_entropy so they also run whenever the loss is evaluated
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        updates = tf.group(*update_ops)
        cross_entropy = control_flow_ops.with_dependencies([updates], cross_entropy)

    # Add summaries for BN variables
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('cross_entropy', cross_entropy)
    for v in tf.global_variables():
        if v.name.startswith('conv1/Batch') or v.name.startswith('conv2/Batch') or \
                v.name.startswith('fc1/Batch') or v.name.startswith('logits/Batch'):
            print(v.name)
            tf.summary.histogram(v.name, v)
    merged_summary_op = tf.summary.merge_all()

    return {'x': x,
            'y_': y_,
            'keep_prob': keep_prob,
            'is_training': is_training,
            'train_step': train_step,
            'global_step': step,
            'accuracy': accuracy,
            'cross_entropy': cross_entropy,
            'summary': merged_summary_op}


def train():
    # clear checkpoint directory
    print('Clearing existing checkpoints and logs')
    for root, sub_folder, file_list in os.walk(FLAGS.checkpoint_dir):
        for f in file_list:
            os.remove(os.path.join(root, f))
    for root, sub_folder, file_list in os.walk(FLAGS.train_log_dir):
        for f in file_list:
            os.remove(os.path.join(root, f))

    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    net = model()
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_log_dir, 'train'), sess.graph)
    valid_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_log_dir, 'valid'), sess.graph)

    # Train
    batch_size = FLAGS.batch_size
    for i in range(10001):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        train_dict = {net['x']: batch_xs,
                      net['y_']: batch_ys,
                      net['keep_prob']: 0.5,
                      net['is_training']: True}
        step, _ = sess.run([net['global_step'], net['train_step']], feed_dict=train_dict)
        if step % 50 == 0:
            train_dict = {net['x']: batch_xs,
                          net['y_']: batch_ys,
                          net['keep_prob']: 1.0,
                          net['is_training']: True}
            entropy, acc, summary = sess.run([net['cross_entropy'], net['accuracy'], net['summary']],
                                             feed_dict=train_dict)
            train_writer.add_summary(summary, global_step=step)
            print('Train step {}: entropy {}: accuracy {}'.format(step, entropy, acc))

            # Note: the validation error is erratic in the beginning (Maybe 2~3k steps).
            # This does NOT imply the batch normalization is buggy.
            # On the contrary, it's BN's dynamics: moving_mean/variance are not estimated that well in the beginning.
            valid_dict = {net['x']: batch_xs,
                          net['y_']: batch_ys,
                          net['keep_prob']: 1.0,
                          net['is_training']: False}
            entropy, acc, summary = sess.run([net['cross_entropy'], net['accuracy'], net['summary']],
                                             feed_dict=valid_dict)
            valid_writer.add_summary(summary, global_step=step)
            print('***** Valid step {}: entropy {}: accuracy {} *****'.format(step, entropy, acc))
    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'mnist-conv-slim'))
    print('Finish training')

    # validation
    acc = 0.0
    batch_size = FLAGS.batch_size
    num_iter = 5000 // batch_size
    for i in range(num_iter):
        batch_xs, batch_ys = mnist.validation.next_batch(batch_size)
        test_dict = {net['x']: batch_xs,
                     net['y_']: batch_ys,
                     net['keep_prob']: 1.0,
                     net['is_training']: False}
        acc_ = sess.run(net['accuracy'], feed_dict=test_dict)
        acc += acc_
    print('Overall validation accuracy {}'.format(acc / num_iter))
    sess.close()


def test():
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    # Test trained model
    net = model()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if ckpt:
        saver.restore(sess, ckpt)
        print("restore from the checkpoint {0}".format(ckpt))

    acc = 0.0
    batch_size = FLAGS.batch_size
    num_iter = 10000 // batch_size
    for i in range(num_iter):
        batch_xs, batch_ys = mnist.test.next_batch(batch_size)
        feed_dict = {net['x']: batch_xs,
                     net['y_']: batch_ys,
                     net['keep_prob']: 1.0,
                     net['is_training']: False}
        acc_ = sess.run(net['accuracy'], feed_dict=feed_dict)
        acc += acc_
    print('Overall test accuracy {}'.format(acc / num_iter))
    sess.close()


def main(_):
    if FLAGS.phase == 'train':
        train()
    else:
        test()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='MNIST_data',
                        help='Directory for storing input data')
    parser.add_argument('--phase', type=str, default='train',
                        help='Training or test phase, should be one of {"train", "test"}')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Mini-batch size for training and evaluation')
    parser.add_argument('--train_log_dir', type=str, default='log',
                        help='Directory for logs')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory for checkpoint file')
    FLAGS, unparsed = parser.parse_known_args()
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.mkdir(FLAGS.checkpoint_dir)
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
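
One remark on the update-op handling above: slim.learning.create_train_op already adds the ops in the UPDATE_OPS collection (the BN moving_mean/moving_variance updates) as dependencies of the train op, and the extra with_dependencies on cross_entropy additionally forces them whenever the loss is fetched. When building the train op by hand instead, the usual TF 1.x pattern is the following sketch (total_loss stands for whatever loss tensor is being minimized; it reuses the tf import from the listing above):

# Sketch: make each optimizer step wait for the BN moving-average updates
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)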
   
   