具体参看这篇博客:https://blog.csdn.net/jiruiYang/article/details/77202674
说得不错,而且这份 GitHub 代码值得借鉴:https://github.com/soloice/mnist-bn
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import os
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops
# Parsed command-line flags (argparse namespace). Placeholder until the
# __main__ block assigns it via parser.parse_known_args() before tf.app.run().
FLAGS = None
def model():
    """Build the batch-normalised convolutional MNIST graph and return its handles.

    Layers (conv/fc layers use tf.nn.crelu activations and slim.batch_norm):
    conv1(16 filters, 5x5) -> pool1(2x2) -> conv2(32 filters, 5x5) -> pool2(2x2)
    -> flatten -> fc1(1024) -> dropout -> logits(10, no activation).

    Returns:
        dict with the graph endpoints used by the train/test drivers:
          'x'             float32 placeholder [None, 784]; reshaped to 28x28x1 images
          'y_'            float32 placeholder [None, 10]; labels fed to the softmax loss
          'keep_prob'     scalar float32 placeholder for dropout
          'is_training'   scalar bool placeholder selecting batch-norm mode
          'train_step'    op that performs one SGD step (learning rate 0.1)
          'global_step'   non-trainable step counter incremented by 'train_step'
          'accuracy'      mean argmax accuracy over the fed batch
          'cross_entropy' mean softmax cross-entropy over the fed batch
          'summary'       merged summary op (scalars + batch-norm histograms)
    """
    x = tf.placeholder(tf.float32, [None, 784])
    keep_prob = tf.placeholder(tf.float32, [])
    y_ = tf.placeholder(tf.float32, [None, 10])
    is_training = tf.placeholder(tf.bool, [])
    x_image = tf.reshape(x, [-1, 28, 28, 1])

    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.crelu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params={'is_training': is_training, 'decay': 0.95}):
        conv1 = slim.conv2d(x_image, 16, [5, 5], scope='conv1')
        pool1 = slim.max_pool2d(conv1, [2, 2], scope='pool1')
        conv2 = slim.conv2d(pool1, 32, [5, 5], scope='conv2')
        pool2 = slim.max_pool2d(conv2, [2, 2], scope='pool2')
        flatten = slim.flatten(pool2)
        fc = slim.fully_connected(flatten, 1024, scope='fc1')
        drop = slim.dropout(fc, keep_prob=keep_prob)
        # activation_fn=None overrides the scope's crelu; the logits layer still
        # receives batch_norm from the arg_scope (hence the 'logits/Batch' filter below).
        logits = slim.fully_connected(drop, 10, activation_fn=None, scope='logits')

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

    step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # create_train_op already makes the train op depend on everything in
    # tf.GraphKeys.UPDATE_OPS (its `update_ops` argument defaults to that
    # collection), which is where slim.batch_norm registers its moving
    # mean/variance updates. Do NOT also attach those updates to
    # `cross_entropy`: doing so re-runs them whenever the loss is merely
    # evaluated with is_training=True, updating the moving statistics twice
    # on the same batch.
    train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)

    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('cross_entropy', cross_entropy)
    # Histogram every variable created under the batch-norm scopes of the
    # normalised layers (tf.global_variables replaces deprecated tf.all_variables).
    bn_prefixes = ('conv1/Batch', 'conv2/Batch', 'fc1/Batch', 'logits/Batch')
    for v in tf.global_variables():
        if v.name.startswith(bn_prefixes):
            print(v.name)
            tf.summary.histogram(v.name, v)
    merged_summary_op = tf.summary.merge_all()

    return {'x': x,
            'y_': y_,
            'keep_prob': keep_prob,
            'is_training': is_training,
            'train_step': train_step,
            'global_step': step,
            'accuracy': accuracy,
            'cross_entropy': cross_entropy,
            'summary': merged_summary_op}
def _clear_directory_files(directory):
    """Delete every file below *directory*, leaving the folder tree in place.

    A non-existent directory is a no-op because os.walk yields nothing for it.
    """
    for root, _, file_names in os.walk(directory):
        for file_name in file_names:
            os.remove(os.path.join(root, file_name))


def train():
    """Train the model, log summaries, save a checkpoint, then score the validation split.

    Side effects:
      * deletes all files under FLAGS.checkpoint_dir and FLAGS.train_log_dir;
      * writes TensorBoard events to <train_log_dir>/train and <train_log_dir>/valid;
      * saves the final weights to <checkpoint_dir>/mnist-conv-slim.
    """
    print('Clearing existed checkpoints and logs')
    _clear_directory_files(FLAGS.checkpoint_dir)
    _clear_directory_files(FLAGS.train_log_dir)

    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    net = model()
    saver = tf.train.Saver()
    batch_size = FLAGS.batch_size

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_log_dir, 'train'), sess.graph)
        valid_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_log_dir, 'valid'), sess.graph)
        try:
            for _ in range(10001):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                train_dict = {net['x']: batch_xs,
                              net['y_']: batch_ys,
                              net['keep_prob']: 0.5,
                              net['is_training']: True}
                step = sess.run([net['global_step'], net['train_step']], feed_dict=train_dict)[0]
                if step % 50 != 0:
                    continue

                # Training-mode metrics on the batch just trained on, dropout disabled.
                eval_train_dict = {net['x']: batch_xs,
                                   net['y_']: batch_ys,
                                   net['keep_prob']: 1.0,
                                   net['is_training']: True}
                entropy, acc, summary = sess.run([net['cross_entropy'], net['accuracy'], net['summary']],
                                                 feed_dict=eval_train_dict)
                train_writer.add_summary(summary, global_step=step)
                print('Train step {}: entropy {}: accuracy {}'.format(step, entropy, acc))

                # Note: the validation error is erratic in the beginning (Maybe 2~3k steps).
                # This does NOT imply the batch normalization is buggy.
                # On the contrary, it's BN's dynamics: moving_mean/variance are not estimated
                # that well in the beginning.
                # Use a real validation batch (previously the training batch was re-used here,
                # so the "valid" curve was not measuring held-out data).
                valid_xs, valid_ys = mnist.validation.next_batch(batch_size)
                valid_dict = {net['x']: valid_xs,
                              net['y_']: valid_ys,
                              net['keep_prob']: 1.0,
                              net['is_training']: False}
                entropy, acc, summary = sess.run([net['cross_entropy'], net['accuracy'], net['summary']],
                                                 feed_dict=valid_dict)
                valid_writer.add_summary(summary, global_step=step)
                print('***** Valid step {}: entropy {}: accuracy {} *****'.format(step, entropy, acc))
        finally:
            # FileWriter buffers events; close() flushes them so the tail is not lost.
            train_writer.close()
            valid_writer.close()

        saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'mnist-conv-slim'))
        print('Finish training')

        # Exact pass over the whole validation split. Slicing (rather than
        # next_batch) covers every example once, including a final partial
        # batch, and each batch's accuracy is weighted by its size.
        dataset = mnist.validation
        num_examples = dataset.num_examples
        correct = 0.0
        for start in range(0, num_examples, batch_size):
            eval_xs = dataset.images[start:start + batch_size]
            eval_ys = dataset.labels[start:start + batch_size]
            test_dict = {net['x']: eval_xs,
                         net['y_']: eval_ys,
                         net['keep_prob']: 1.0,
                         net['is_training']: False}
            batch_acc = sess.run(net['accuracy'], feed_dict=test_dict)
            correct += batch_acc * len(eval_xs)
        print('Overall validation accuracy {}'.format(correct / max(num_examples, 1)))
def test():
    """Restore the latest checkpoint and report accuracy on the full MNIST test split.

    If FLAGS.checkpoint_dir holds no checkpoint, evaluation still runs on the
    freshly initialised weights, but a warning is printed so the result is not
    mistaken for a trained model's accuracy.
    """
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    net = model()
    saver = tf.train.Saver()
    batch_size = FLAGS.batch_size

    # The context manager closes the session even if evaluation raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            print('WARNING: no checkpoint found in {0}; '
                  'evaluating untrained weights'.format(FLAGS.checkpoint_dir))

        # Exact pass over the test split: slicing covers every example once
        # (including a final partial batch) and avoids a zero iteration count
        # when batch_size exceeds the split size.
        dataset = mnist.test
        num_examples = dataset.num_examples
        correct = 0.0
        for start in range(0, num_examples, batch_size):
            batch_xs = dataset.images[start:start + batch_size]
            batch_ys = dataset.labels[start:start + batch_size]
            feed_dict = {net['x']: batch_xs,
                         net['y_']: batch_ys,
                         net['keep_prob']: 1.0,
                         net['is_training']: False}
            batch_acc = sess.run(net['accuracy'], feed_dict=feed_dict)
            correct += batch_acc * len(batch_xs)
        print('Overall test accuracy {}'.format(correct / max(num_examples, 1)))
def main(_):
    """tf.app.run entry point: pick the phase handler from FLAGS.phase.

    'train' runs train(); every other value falls through to test().
    The positional argument (argv forwarded by tf.app.run) is ignored.
    """
    phase_handler = train if FLAGS.phase == 'train' else test
    phase_handler()
if __name__ == '__main__':
    # Command-line configuration; unknown arguments are forwarded to tf.app.run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='MNIST_data',
                        help='Directory for storing input data')
    # `choices` rejects typos instead of silently running the test phase.
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test'],
                        help='Training or test phase, should be one of {"train", "test"}')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Mini-batch size used for training and evaluation')
    parser.add_argument('--train_log_dir', type=str, default='log',
                        help='Directory for logs')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory for checkpoint file')
    FLAGS, unparsed = parser.parse_known_args()
    if FLAGS.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')
    # makedirs (unlike mkdir) also creates missing parent directories,
    # e.g. --checkpoint_dir runs/exp1/checkpoint.
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
<link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/markdown_views-ea0013b516.css">
</div>