I. fully_connected_feed.py
1. The main function
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.01,
    help='Initial learning rate.'
)
...
FLAGS, unparsed = parser.parse_known_args()
...
run_training()
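The main function uses argparse to define parameters such as learning_rate, max_steps, input_data_dir, hidden1, hidden2 and batch_size. It stores the parsed values in FLAGS and then calls run_training to train the model. parse_known_args() returns two values: the namespace of recognized flags (FLAGS) and a list of any arguments it did not recognize (unparsed).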
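The difference between FLAGS and unparsed can be seen in a minimal standalone sketch. It reuses the --learning_rate flag shown above and adds a --max_steps flag whose default of 2000 is only an assumption for illustration. Recognized flags end up as attributes of the namespace. An unrecognized flag is returned in the list instead of causing an error.

```python
import argparse

def build_parser():
    """Build a parser with a subset of the flags (the max_steps default is assumed)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='Initial learning rate.')
    parser.add_argument('--max_steps', type=int, default=2000,
                        help='Number of training steps (assumed default).')
    return parser

# parse_known_args() does not fail on unknown flags; it hands them back separately.
FLAGS, unparsed = build_parser().parse_known_args(
    ['--learning_rate', '0.05', '--some_unknown_flag=1'])

print(FLAGS.learning_rate)  # 0.05
print(FLAGS.max_steps)      # 2000 (default)
print(unparsed)             # ['--some_unknown_flag=1']
```

This behavior is useful when the leftover arguments are meant for another consumer, such as a framework's own flag handling.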
2. The run_training() function
# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder,
                         FLAGS.hidden1,
                         FLAGS.hidden2)
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS.learning_rate)
# Add the Op to compare the logits to the labels during evaluation.
eval_correct = mnist.evaluation(logits, labels_placeholder)
...
# Start the training loop.
for step in range(FLAGS.max_steps):
    start_time = time.time()
    # Fill a feed dictionary with the actual set of images and labels
    # for this particular training step.
    feed_dict = fill_feed_dict(data_sets.train,
                               images_placeholder,
                               labels_placeholder)
    # Run one step of the model. The return values are the activations
    # from the `train_op` (which is discarded) and the `loss` Op. To
    # inspect the values of your Ops or variables, you may include them
    # in the list passed to sess.run() and the value tensors will be
    # returned in the tuple from the call.
    _, loss_value = sess.run([train_op, loss],
                             feed_dict=feed_dict)
    duration = time.time() - start_time
    ...
    # Save a checkpoint and evaluate the model periodically.
    if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)
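run_training first adds logits, loss, train_op and eval_correct to the computation graph. Inside the training loop it calls sess.run([train_op, loss], feed_dict=feed_dict) once per step to train. Every 1000 steps, and at the final step, it saves a checkpoint and calls do_eval to evaluate the model on the training, validation and test sets.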
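The condition (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps can be checked on its own. The pure-Python sketch below uses a hypothetical helper, eval_steps, and example max_steps values. It lists the 0-based steps that trigger a checkpoint and an evaluation. Note that the last step is always included, even when it is not a multiple of 1000.

```python
def eval_steps(max_steps, every=1000):
    """Return the 0-based steps at which the loop above saves and evaluates."""
    return [step for step in range(max_steps)
            if (step + 1) % every == 0 or (step + 1) == max_steps]

print(eval_steps(2500))  # [999, 1999, 2499]
print(eval_steps(2000))  # [999, 1999]
```

Because saver.save is called with global_step=step, the checkpoint files are also suffixed with these 0-based step numbers.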
3. The do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set) function
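true_count = 0  # Number of correct predictions.
steps_per_epoch = data_set.num_examples // FLAGS.batch_size
num_examples = steps_per_epoch * FLAGS.batch_size
for step in range(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
precision = float(true_count) / num_examples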
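For each of the steps_per_epoch batches, do_eval builds a feed_dict with fill_feed_dict. It then calls sess.run(eval_correct, feed_dict=feed_dict) to get the number of correct predictions in that batch. Dividing the running total by num_examples gives the accuracy, which the code names precision.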
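The accumulation in do_eval can be mimicked without TensorFlow. This sketch assumes, as in the TensorFlow example, that steps_per_epoch = num_examples // batch_size, so a final partial batch is skipped. It also replaces fill_feed_dict/next_batch with simple sequential slicing. batched_accuracy is a hypothetical name, and the per-example 0/1 flags stand in for the output of eval_correct.

```python
def batched_accuracy(correct_flags, batch_size):
    """Mimic do_eval: count correct predictions over full batches only.

    correct_flags: list of 0/1 values, one per example (1 = predicted correctly).
    """
    steps_per_epoch = len(correct_flags) // batch_size  # the partial batch is dropped
    num_examples = steps_per_epoch * batch_size
    true_count = 0
    for step in range(steps_per_epoch):
        batch = correct_flags[step * batch_size:(step + 1) * batch_size]
        true_count += sum(batch)  # plays the role of sess.run(eval_correct, ...)
    return true_count, num_examples, float(true_count) / num_examples

# 10 examples, batch size 4 -> 2 full batches (8 examples); the last 2 are skipped.
flags = [1, 1, 0, 1, 1, 1, 1, 0, 0, 0]
print(batched_accuracy(flags, 4))  # (6, 8, 0.75)
```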
4. mnist.py
def inference(images, hidden1_units, hidden2_units):
    ...  # two hidden layers followed by a linear output layer
    return logits
def loss(logits, labels):
    ...
    return loss
def training(loss, learning_rate):
    ...
    return train_op
def evaluation(logits, labels):
    correct = tf.nn.in_top_k(logits, labels, 1)
    return tf.reduce_sum(tf.cast(correct, tf.int32))
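In mnist.py, inference builds the model and computes the logits, and loss computes the loss from the logits and the labels. training defines the optimization algorithm and returns the training op train_op. evaluation counts how many examples in a batch are predicted correctly.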
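The roles of inference, loss and evaluation can be illustrated with a NumPy stand-in. This is not the TensorFlow code in mnist.py. The layer sizes 784/128/32/10 are assumed example values, and training is omitted because it needs gradient computation. For k=1, the argmax-based count matches tf.nn.in_top_k except when logits are tied.

```python
import numpy as np

def inference(images, weights):
    """Two ReLU hidden layers followed by a linear layer producing logits.

    weights: list of (W, b) pairs; the sizes used below are illustrative assumptions.
    """
    (w1, b1), (w2, b2), (w3, b3) = weights
    hidden1 = np.maximum(images @ w1 + b1, 0.0)
    hidden2 = np.maximum(hidden1 @ w2 + b2, 0.0)
    return hidden2 @ w3 + b3  # logits: one score per class, no softmax applied

def loss(logits, labels):
    """Mean softmax cross-entropy with integer labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def evaluation(logits, labels):
    """Number of examples whose top-1 prediction equals the label."""
    return int(np.sum(np.argmax(logits, axis=1) == labels))

rng = np.random.default_rng(0)
sizes = [784, 128, 32, 10]  # pixels, hidden1, hidden2, classes (assumed values)
weights = [(rng.normal(scale=0.01, size=(m, n)), np.zeros(n))
           for m, n in zip(sizes[:-1], sizes[1:])]
images = rng.random((5, 784))
labels = np.array([0, 1, 2, 3, 4])

logits = inference(images, weights)
print(logits.shape)                # (5, 10)
print(loss(logits, labels))        # close to log(10), about 2.30, for tiny random weights
print(evaluation(logits, labels))  # an integer between 0 and 5
```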
II. resnet.py
1. The main function
The key code is as follows:
import tensorflow as tf
mnist = tf.contrib.learn.datasets.DATASETS['mnist']('/tmp/mnist')
classifier = tf.estimator.Estimator(model_fn=res_net_model)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={X_FEATURE: mnist.train.images},
    # ... other arguments omitted
)
classifier.train(input_fn=train_input_fn, steps=100)
scores = classifier.evaluate(input_fn=test_input_fn)
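Here mnist holds the dataset. classifier is an Estimator, which can both train and evaluate, and it is constructed with res_net_model as its model_fn. The dictionary x={X_FEATURE: mnist.train.images} plays a role similar to a placeholder. numpy_input_fn turns the NumPy array into an input function that feeds the data to the model in batches, and the model function looks the images up under the same X_FEATURE key.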
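The placeholder-like role of the feature dictionary can be made concrete with a small NumPy sketch. make_input_fn is a hypothetical simplification, not tf.estimator.inputs.numpy_input_fn: it makes a single pass with no shuffling. The value 'x' for X_FEATURE is also an assumption. What matters is that every batch arrives as a dict keyed by X_FEATURE, which is how the model side finds the images.

```python
import numpy as np

X_FEATURE = 'x'  # feature key; the value 'x' is an assumption for this sketch

def make_input_fn(features, labels, batch_size):
    """Hypothetical stand-in for numpy_input_fn: yields (features_dict, labels) batches.

    Unlike the real function, it does no shuffling and makes exactly one pass.
    """
    def input_fn():
        for start in range(0, len(labels), batch_size):
            end = start + batch_size
            yield {X_FEATURE: features[start:end]}, labels[start:end]
    return input_fn

images = np.arange(10 * 4, dtype=np.float32).reshape(10, 4)  # 10 fake "images"
targets = np.arange(10)
input_fn = make_input_fn(images, targets, batch_size=4)

for feats, labs in input_fn():
    # A model function would read the batch by the same key: feats[X_FEATURE]
    print(feats[X_FEATURE].shape, labs.tolist())
# (4, 4) [0, 1, 2, 3]
# (4, 4) [4, 5, 6, 7]
# (2, 4) [8, 9]
```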
2. The res_net_model(features, labels, mode) function
The key code is as follows:
1) Configuration of the ResNet bottleneck groups, using a namedtuple to hold each group's parameters
from collections import namedtuple
BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
]
There are 4 groups with 3 bottleneck blocks each, for 12 bottleneck blocks in total.
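This configuration can be checked in plain Python. The namedtuple lets a layer-building loop read group.num_filters and the other fields by name. Summing num_blocks confirms the total of 4 × 3 = 12 blocks.

```python
from collections import namedtuple

BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
]

# Fields are read by name, which keeps the building loop readable.
for group_i, group in enumerate(groups):
    print(group_i, group.num_blocks, group.num_filters, group.bottleneck_size)

total_blocks = sum(g.num_blocks for g in groups)
print(total_blocks)  # 12
```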
2) Structure of a bottleneck block
with tf.variable_scope(name + '/conv_in'):
    conv = tf.layers.conv2d(
        net,
        filters=group.num_filters,
        kernel_size=1,
        padding='valid',
        activation=tf.nn.relu)
    conv = tf.layers.batch_normalization(conv)
...
net = net + conv
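tf.variable_scope opens a named scope for the variables of this conv layer. Variables created inside it get name + '/conv_in' as a name prefix, which keeps the weights of different blocks distinct and groups them in the computation graph.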
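tf.layers.conv2d builds a 2D convolution layer. Here kernel_size=1 makes it a 1x1 convolution, which only mixes channels at each pixel.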
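The line net = net + conv is strikingly simple, but it does not add two models together. It adds two tensors element-wise: the block's input net and the output conv of its convolutions. This shortcut (skip) connection is what makes the network residual, and it requires conv to have the same shape as net.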
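A simplified NumPy sketch shows why the addition works and what it computes. As simplifying assumptions, the 3x3 bottleneck convolution and batch normalization are left out, and the elided part of the block is not reproduced. Instead, a final 1x1 convolution is assumed to restore the input channel count. A 1x1 convolution is written as a per-pixel matrix multiply over channels.

```python
import numpy as np

def conv1x1(x, w):
    """1x1 convolution on an (H, W, C_in) feature map: a per-pixel matmul over channels."""
    return np.einsum('hwc,cd->hwd', x, w)

def bottleneck_block(net, w_in, w_out):
    """Simplified residual block (3x3 conv and batch norm omitted by assumption).

    w_in:  (C, num_filters) weights of the first 1x1 conv
    w_out: (num_filters, C) weights of the last 1x1 conv, restoring C channels
    """
    conv = np.maximum(conv1x1(net, w_in), 0.0)    # conv_in followed by ReLU
    conv = np.maximum(conv1x1(conv, w_out), 0.0)  # back to the input channel count
    assert conv.shape == net.shape, "the shortcut needs matching shapes"
    return net + conv                             # element-wise tensor addition

rng = np.random.default_rng(0)
net = rng.random((7, 7, 64))
w_in = rng.normal(size=(64, 32))
w_out = rng.normal(size=(32, 64))

out = bottleneck_block(net, w_in, w_out)
print(out.shape)  # (7, 7, 64)

# With all-zero weights the convolution path outputs 0, so the block is an identity map.
zero_out = bottleneck_block(net, np.zeros((64, 32)), np.zeros((32, 64)))
print(np.allclose(zero_out, net))  # True
```

The zero-weight case illustrates the usual motivation for residual connections: a block can fall back to passing its input through unchanged.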
3) Model outputs
logits = tf.layers.dense(net, N_DIGITS, activation=None)
predicted_classes = tf.argmax(logits, 1)
# Prediction
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class': predicted_classes,
        'prob': tf.nn.softmax(logits)
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
# Training (the loss below is also reused for evaluation)
onehot_labels = tf.one_hot(tf.cast(labels, tf.int32), N_DIGITS, 1, 0)
loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
# Evaluation
eval_metric_ops = {
    'accuracy': tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)
}
return tf.estimator.EstimatorSpec(
    mode, loss=loss, eval_metric_ops=eval_metric_ops)
As the code shows, the model returns a different EstimatorSpec depending on mode (PREDICT, TRAIN or EVAL).
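The EstimatorSpec arguments used here are mode, predictions, loss, train_op and eval_metric_ops. PREDICT requires predictions, TRAIN requires loss and train_op, and EVAL requires loss.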
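The per-mode branching can be imitated in NumPy to show which fields each branch produces. Several assumptions apply. model_fn_sketch is hypothetical, the mode constants are plain-string stand-ins for tf.estimator.ModeKeys, and train_op is left as None because there is no optimizer here. Accuracy is also a plain float, whereas tf.metrics.accuracy returns a streaming (value, update_op) pair.

```python
import numpy as np

N_DIGITS = 10
# Plain-string stand-ins for tf.estimator.ModeKeys (values chosen for this sketch).
PREDICT, TRAIN, EVAL = 'infer', 'train', 'eval'

def log_softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def model_fn_sketch(logits, labels, mode):
    """Return a dict with the EstimatorSpec-like fields each mode needs (hypothetical)."""
    predicted_classes = np.argmax(logits, axis=1)
    if mode == PREDICT:  # neither labels nor loss are needed
        return {'predictions': {'class': predicted_classes,
                                'prob': np.exp(log_softmax(logits))}}
    # Shared by TRAIN and EVAL, like the loss in res_net_model.
    onehot_labels = np.eye(N_DIGITS)[labels]  # same role as tf.one_hot(..., N_DIGITS, 1, 0)
    loss = -np.mean(np.sum(onehot_labels * log_softmax(logits), axis=1))
    if mode == TRAIN:
        # A real model_fn would return optimizer.minimize(loss, ...) here.
        return {'loss': loss, 'train_op': None}
    accuracy = float(np.mean(predicted_classes == labels))
    return {'loss': loss, 'eval_metric_ops': {'accuracy': accuracy}}

logits = np.array([[2.0, 0.5] + [0.0] * 8,
                   [0.1, 3.0] + [0.0] * 8])
labels = np.array([0, 0])
print(sorted(model_fn_sketch(logits, labels, PREDICT)))          # ['predictions']
print(sorted(model_fn_sketch(logits, labels, TRAIN)))            # ['loss', 'train_op']
print(model_fn_sketch(logits, labels, EVAL)['eval_metric_ops'])  # {'accuracy': 0.5}
```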
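Because it relies on tf.estimator and tf.layers (e.g. tf.layers.conv2d), the model code in resnet.py is more highly encapsulated than fully_connected_feed.py. It creates no placeholders, session or feed_dict of its own, and the Estimator runs the training loop.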