TensorFlow Code Reading

I. fully_connected_feed.py

1. Main function

parser = argparse.ArgumentParser()
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.01,
    help='Initial learning rate.'
)
...
FLAGS, unparsed = parser.parse_known_args()
...
run_training()

The main function uses argparse to define parameters such as learning_rate, max_steps, input_data_dir, hidden1, hidden2 and batch_size, stores them in FLAGS, and then calls run_training to train the model.
2. run_training() function
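To see what parse_known_args actually returns, here is a minimal stand-alone sketch. Only --learning_rate is taken from the excerpt above; --max_steps, --batch_size and their defaults are illustrative assumptions, and a fixed argument list is passed instead of reading sys.argv.

```python
import argparse

parser = argparse.ArgumentParser()
# Taken from the excerpt above.
parser.add_argument('--learning_rate', type=float, default=0.01,
                    help='Initial learning rate.')
# Illustrative flags; the defaults here are assumptions.
parser.add_argument('--max_steps', type=int, default=2000,
                    help='Number of steps to run the trainer.')
parser.add_argument('--batch_size', type=int, default=100,
                    help='Batch size.')

# parse_known_args() returns (namespace, leftover_args) instead of
# raising an error on flags it does not recognize.
FLAGS, unparsed = parser.parse_known_args(
    ['--learning_rate', '0.05', '--extra_flag=1'])

print(FLAGS.learning_rate)  # 0.05
print(FLAGS.max_steps)      # 2000
print(unparsed)             # ['--extra_flag=1']
```

Unrecognized flags end up in unparsed rather than stopping the program, which is why the script unpacks two values from parse_known_args().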

# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder,
                         FLAGS.hidden1,
                         FLAGS.hidden2)
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS.learning_rate)
# Add the Op to compare the logits to the labels during evaluation.
eval_correct = mnist.evaluation(logits, labels_placeholder)
...
# Start the training loop.
for step in xrange(FLAGS.max_steps):
  start_time = time.time()
  # Fill a feed dictionary with the actual set of images and labels
  # for this particular training step.
  feed_dict = fill_feed_dict(data_sets.train,
                             images_placeholder,
                             labels_placeholder)
  # Run one step of the model.  The return values are the activations
  # from the `train_op` (which is discarded) and the `loss` Op.  To
  # inspect the values of your Ops or variables, you may include them
  # in the list passed to sess.run() and the value tensors will be
  # returned in the tuple from the call.
  _, loss_value = sess.run([train_op, loss],
                           feed_dict=feed_dict)
  duration = time.time() - start_time
  ...
  # Save a checkpoint and evaluate the model periodically.
  if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
    checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
    saver.save(sess, checkpoint_file, global_step=step)
    # Evaluate against the training set.
    print('Training Data Eval:')
    do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_sets.train)
    # Evaluate against the validation set.
    print('Validation Data Eval:')
    do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_sets.validation)
    # Evaluate against the test set.
    print('Test Data Eval:')
    do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_sets.test)

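run_training first adds logits, loss, train_op and eval_correct to the computation graph. It then trains by calling sess.run([train_op, loss], feed_dict=feed_dict) at every step, and every 1000 steps (and at the final step) it saves a checkpoint and calls do_eval to evaluate on the train, validation and test sets.
3. do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set) function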
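The save/evaluate condition in the loop can be checked in isolation. This small sketch (the helper name checkpoint_steps and the value max_steps=2500 are hypothetical) lists the 0-based steps at which the loop would save a checkpoint and run do_eval:

```python
def checkpoint_steps(max_steps, every=1000):
    """Return the 0-based steps at which the loop saves and evaluates."""
    return [step for step in range(max_steps)
            if (step + 1) % every == 0 or (step + 1) == max_steps]


print(checkpoint_steps(2500))  # [999, 1999, 2499]
```

Because saver.save is called with global_step=step, the step number is appended to the checkpoint prefix, so the files are named like model.ckpt-999.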

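true_count = 0  # Counts the number of correct predictions.
steps_per_epoch = data_set.num_examples // FLAGS.batch_size
num_examples = steps_per_epoch * FLAGS.batch_size
for step in xrange(steps_per_epoch):
  feed_dict = fill_feed_dict(data_set,
                             images_placeholder,
                             labels_placeholder)
  true_count += sess.run(eval_correct, feed_dict=feed_dict)
precision = float(true_count) / num_examples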

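For each batch, do_eval fills a feed_dict and calls sess.run(eval_correct, feed_dict=feed_dict) to get the number of correct predictions in that batch. The counts are summed into true_count and divided by the number of examples to give the precision.
4. mnist.py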
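A plain-Python sketch of the precision computation in do_eval, assuming (as in the TensorFlow tutorial) that only full batches are evaluated, so the divisor is steps_per_epoch * batch_size. The list correct_per_batch is hypothetical and stands in for the values returned by sess.run(eval_correct, ...):

```python
def do_eval_sketch(correct_per_batch, num_examples, batch_size):
    """Return (true_count, evaluated_examples, precision).

    correct_per_batch[i] is a hypothetical stand-in for
    sess.run(eval_correct, feed_dict=...) on the i-th batch.
    """
    steps_per_epoch = num_examples // batch_size
    evaluated = steps_per_epoch * batch_size  # A trailing partial batch is skipped.
    true_count = 0
    for step in range(steps_per_epoch):
        true_count += correct_per_batch[step]
    return true_count, evaluated, true_count / evaluated


counts = [95, 90, 85, 100, 80, 90, 95, 85, 90, 90]
print(do_eval_sketch(counts, num_examples=1050, batch_size=100))  # (900, 1000, 0.9)
```

With 1050 examples and a batch size of 100, only 10 full batches (1000 examples) are evaluated; the remaining 50 are ignored.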

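import tensorflow as tf

def inference(images, hidden1_units, hidden2_units):
    # ... builds the model and computes logits ...
    return logits
def loss(logits, labels):
    # ... computes the loss from logits and labels ...
    return loss
def training(loss, learning_rate):
    # ... defines the optimizer and the op that applies gradients ...
    return train_op
def evaluation(logits, labels):
    # correct[i] is True when the true label is the top-1 prediction.
    correct = tf.nn.in_top_k(logits, labels, 1)
    return tf.reduce_sum(tf.cast(correct, tf.int32))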

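In this file, inference builds the model and computes the logits; loss computes the loss from the logits and labels; training defines the optimization algorithm and returns the training op train_op; evaluation counts how many examples are predicted correctly.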
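What loss and evaluation compute can be written out for a small batch in plain Python. This is a sketch, not the mnist.py code: it assumes integer class labels, takes the batch loss to be the mean softmax cross-entropy, and counts an example as correct when its largest logit is at the true label.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    m = max(logits)  # Subtract the max for numerical stability.
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def loss(batch_logits, labels):
    """Mean cross-entropy over the batch."""
    total = sum(softmax_cross_entropy(row, y) for row, y in zip(batch_logits, labels))
    return total / len(labels)


def evaluation(batch_logits, labels):
    """Number of examples whose largest logit is at the true label."""
    return sum(max(range(len(row)), key=row.__getitem__) == y
               for row, y in zip(batch_logits, labels))


batch_logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0], [1.0, 0.0, 3.0]]
labels = [0, 1, 0]
print(evaluation(batch_logits, labels))  # 2
print(loss([[0.0, 0.0]], [0]))           # 0.6931471805599453, i.e. log(2)
```

Equal logits over two classes give a loss of log(2), the cross-entropy of a uniform guess.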

II. resnet.py

1. main function
Key code:

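mnist = tf.contrib.learn.datasets.DATASETS['mnist']('/tmp/mnist')
classifier = tf.estimator.Estimator(model_fn=res_net_model)
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={X_FEATURE: mnist.train.images},...)
classifier.train(input_fn=train_input_fn, steps=100)
scores = classifier.evaluate(input_fn=test_input_fn)

Here mnist holds the dataset and classifier is an Estimator (it can both train and evaluate), constructed with res_net_model as its model_fn. x={X_FEATURE: mnist.train.images} plays roughly the role of a placeholder: instead of a feed_dict, the input function feeds batches of images to the model as a dict of features keyed by X_FEATURE.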
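The "dict of features" idea can be mimicked in plain Python. This is only an analogue under stated assumptions: numpy_input_fn actually returns tensors, not a Python generator, and X_FEATURE = 'x' and the helper names make_input_fn and model_fn_reads are hypothetical.

```python
X_FEATURE = 'x'  # Hypothetical value; resnet.py defines its own X_FEATURE constant.


def make_input_fn(images, labels, batch_size):
    """Return an input function yielding ({X_FEATURE: images}, labels) batches.

    Plain-Python analogue only: the real numpy_input_fn returns tensors.
    """
    def input_fn():
        for start in range(0, len(images), batch_size):
            end = start + batch_size
            yield {X_FEATURE: images[start:end]}, labels[start:end]
    return input_fn


def model_fn_reads(features):
    # The model function looks the images up under the same key.
    return features[X_FEATURE]


input_fn = make_input_fn([[0.0], [0.1], [0.2], [0.3], [0.4]], [0, 1, 2, 3, 4], batch_size=2)
for features, labels in input_fn():
    print(model_fn_reads(features), labels)
```

The point is that the input side and the model side only have to agree on the feature key; the model never sees a placeholder.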

2. res_net_model(features, labels, mode) function
Key code:
1) Configuration of the ResNet bottleneck groups: a namedtuple holds each group's parameters

  BottleneckGroup = namedtuple('BottleneckGroup',
                               ['num_blocks', 'num_filters', 'bottleneck_size'])
  groups = [
      BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
      BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
  ]

There are 4 groups, each containing 3 bottleneck blocks.
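The group configuration can be run on its own to confirm the block count. The 'group_%d/block_%d' naming in this sketch is an illustrative assumption about how the nested loop labels each block:

```python
from collections import namedtuple

BottleneckGroup = namedtuple('BottleneckGroup',
                             ['num_blocks', 'num_filters', 'bottleneck_size'])
groups = [
    BottleneckGroup(3, 128, 32), BottleneckGroup(3, 256, 64),
    BottleneckGroup(3, 512, 128), BottleneckGroup(3, 1024, 256)
]

# Enumerate every block the way a nested loop over groups would.
block_names = ['group_%d/block_%d' % (group_i, block_i)
               for group_i, group in enumerate(groups)
               for block_i in range(group.num_blocks)]

print(len(block_names))       # 12
print(block_names[:2])        # ['group_0/block_0', 'group_0/block_1']
print(groups[1].num_filters)  # 256
```

A namedtuple lets the layer code read fields by name (group.num_filters) instead of by position, which keeps the configuration readable.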
2) Structure of a bottleneck block

with tf.variable_scope(name + '/conv_in'):
  conv = tf.layers.conv2d(
      net,
      filters=group.num_filters,
      kernel_size=1,
      padding='valid',
      activation=tf.nn.relu)
  conv = tf.layers.batch_normalization(conv)
...
net = net+conv

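tf.variable_scope prefixes the names of the variables created inside it with the scope name (here name + '/conv_in'), so each block's convolution weights get a distinct name in the computation graph.
tf.layers.conv2d builds a 2D convolutional layer; here it is a 1x1 convolution with ReLU activation, followed by batch normalization.
net = net + conv is the really neat part! It does not add two models: it adds the block's input tensor and output tensor element-wise. This is ResNet's shortcut connection, and it requires conv to have the same shape as net.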
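The shortcut addition can be illustrated with plain lists standing in for tensors. The helper toy_branch is hypothetical and stands in for the conv_in ... conv_out stack; the sketch only shows that the block returns its input plus the branch output, element by element.

```python
def relu(values):
    return [max(0.0, v) for v in values]


def residual_block(x, branch):
    """Return x + branch(x), added element by element (the shortcut)."""
    fx = branch(x)
    if len(fx) != len(x):
        raise ValueError('the shortcut needs input and output of the same shape')
    return [a + b for a, b in zip(x, fx)]


def toy_branch(values):
    # Hypothetical stand-in for the convolution stack of one block.
    return relu([0.5 * v - 1.0 for v in values])


print(residual_block([1.0, 2.0, 4.0], toy_branch))  # [1.0, 2.0, 5.0]
```

If the branch outputs all zeros, the block simply passes its input through unchanged.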
3) Model output

logits = tf.layers.dense(net, N_DIGITS, activation=None)
predicted_classes = tf.argmax(logits, 1)
# Prediction
if mode == tf.estimator.ModeKeys.PREDICT:
  predictions = {
      'class': predicted_classes,
      'prob': tf.nn.softmax(logits)
  }
  return tf.estimator.EstimatorSpec(mode, predictions=predictions)
# Training
onehot_labels = tf.one_hot(tf.cast(labels, tf.int32), N_DIGITS, 1, 0)
loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
  optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
# Evaluation
eval_metric_ops = {
    'accuracy': tf.metrics.accuracy(
        labels=labels, predictions=predicted_classes)
}
return tf.estimator.EstimatorSpec(
    mode, loss=loss, eval_metric_ops=eval_metric_ops)
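In the call tf.one_hot(..., N_DIGITS, 1, 0), the last two arguments are on_value=1 and off_value=0. The plain-Python sketch below (N_DIGITS = 10 and the example logits are assumptions) shows that cross-entropy with one-hot labels only picks up the true class, so it equals the negative log-softmax at that label:

```python
import math


def one_hot(label, depth, on_value=1, off_value=0):
    """Plain-Python analogue of tf.one_hot for a single integer label."""
    return [on_value if i == label else off_value for i in range(depth)]


def log_softmax(logits):
    m = max(logits)  # Subtract the max for numerical stability.
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def cross_entropy_one_hot(onehot_labels, logits):
    """-sum_i onehot_labels[i] * log(softmax(logits)[i])."""
    return -sum(t * lp for t, lp in zip(onehot_labels, log_softmax(logits)))


N_DIGITS = 10  # Ten MNIST classes (assumed).
logits = [0.3, 2.0, -1.0, 0.0, 0.5, 0.1, -0.2, 1.2, 0.4, 0.0]
label = 1
onehot = one_hot(label, N_DIGITS)
print(onehot)  # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
# Only the true class contributes to the sum.
print(cross_entropy_one_hot(onehot, logits) == -log_softmax(logits)[label])  # True
```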

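As the code shows, the model returns a different EstimatorSpec depending on mode: predictions in PREDICT mode, loss and train_op in TRAIN mode, and loss and eval_metric_ops in EVAL mode.
The arguments EstimatorSpec accepts include mode, predictions, loss, train_op and eval_metric_ops.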
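The mode dispatch can be mirrored with plain Python to see which fields each branch fills in. Everything here is a hypothetical stand-in: PREDICT/TRAIN/EVAL replace tf.estimator.ModeKeys, Spec replaces tf.estimator.EstimatorSpec, and the "loss" is just a count of misclassified examples.

```python
from collections import namedtuple

# Hypothetical stand-ins for tf.estimator.ModeKeys and tf.estimator.EstimatorSpec.
PREDICT, TRAIN, EVAL = 'predict', 'train', 'eval'
Spec = namedtuple('Spec',
                  ['mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops'],
                  defaults=(None, None, None, None))


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def model_fn_sketch(logits, labels, mode):
    predicted = [argmax(row) for row in logits]
    if mode == PREDICT:
        # Return early: labels may be None at prediction time.
        return Spec(mode, predictions={'class': predicted})
    # Placeholder "loss": number of misclassified examples.
    errors = sum(p != y for p, y in zip(predicted, labels))
    if mode == TRAIN:
        return Spec(mode, loss=errors, train_op='train_step')
    accuracy = 1 - errors / len(labels)
    return Spec(mode, loss=errors, eval_metric_ops={'accuracy': accuracy})


logits = [[0.1, 0.9], [0.8, 0.2]]
print(model_fn_sketch(logits, None, PREDICT).predictions)     # {'class': [1, 0]}
print(model_fn_sketch(logits, [1, 1], TRAIN).loss)            # 1
print(model_fn_sketch(logits, [1, 1], EVAL).eval_metric_ops)  # {'accuracy': 0.5}
```

Returning early in PREDICT mode matters because labels are not available at prediction time, so nothing label-dependent (one-hot labels, loss) can be built in that branch.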
Thanks to tf.estimator and tf.layers.conv2d, the model code in resnet.py feels much more highly encapsulated than fully_connected_feed.py!
