2017.12 北京 房间内 午饭后 纸上学来终觉浅,绝知此事要躬行
CIFAR-10分类问题是机器学习中一个经典的分类问题,图片库中包含10类物体,每一张图片都是32×32像素的RGB图片。本文就TensorFlow官网上CIFAR-10分类问题的源代码进行详细解释。官网中采用了一个较小的CNN网络实现,但是代码中用到了TensorFlow的很多API功能,方便初学者学习。
代码主要包括以下几个部分:
我们从训练模型开始,文件为cifar10_train.py(下面贴出),运行这个脚本就可以进行训练操作。第一次运行的时候,代码会自动下载数据集。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import time
import tensorflow as tf
import cifar10
def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because every non-empty string is truthy, so the flag could
    never be switched off from the command line.

    Args:
        value: The raw command-line text (or an actual ``bool``).

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0 or an
        empty string; matching ignores case and surrounding whitespace.

    Raises:
        ValueError: If the text is not a recognised boolean spelling.
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got %r' % (value,))


# Extend the argument parser that cifar10.py exposes with the flags that
# are specific to training.
parser = cifar10.parser

# Where the trained model (checkpoints and event logs) is stored.
parser.add_argument('--train_dir', type=str,
                    default='/home/liu/NewDisk/LearnTensor/cifar10_train',
                    help='Directory where to write event logs and checkpoint.')

# Number of training iterations.
parser.add_argument('--max_steps', type=int, default=10000,
                    help='Number of batches to run.')

# Parsed with _parse_bool so that "--log_device_placement False" really
# yields False.
parser.add_argument('--log_device_placement', type=_parse_bool, default=False,
                    help='Whether to log device placement.')

parser.add_argument('--log_frequency', type=int, default=10,
                    help='How often to log results to the console.')
# train() mostly calls functions defined in cifar10.py; see cifar10.py for the actual implementations.
def train():
    """Train CIFAR-10 for a number of steps.

    Assembles the training graph from the helpers in ``cifar10`` and then
    runs the training op inside a monitored session until one of the hooks
    requests a stop. All settings are read from the module-level ``FLAGS``.
    """
    graph = tf.Graph()
    with graph.as_default():  # Build everything below in this fresh graph.
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10. The input pipeline is pinned
        # to CPU:0 so that its operations do not sometimes end up on the GPU
        # and slow training down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Inference model: per the original author's note this returns the
        # output of the final linear layer (WX + b) with no softmax applied.
        logits = cifar10.inference(images)

        # Loss computed from the logits and the labels (cross-entropy,
        # according to the original author's note).
        loss = cifar10.loss(logits, labels)

        # Op that trains the model with one batch of examples and updates
        # the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                # Starts at -1 so that the first before_run() makes it 0.
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Ask the session to also fetch the loss value on this run.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                # Only report on every FLAGS.log_frequency-th step.
                if self._step % FLAGS.log_frequency != 0:
                    return

                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                # FLAGS.batch_size is not defined in this file; presumably
                # it comes from the parser in cifar10.py.
                steps_per_report = FLAGS.log_frequency
                examples_per_sec = steps_per_report * FLAGS.batch_size / elapsed
                sec_per_batch = float(elapsed / steps_per_report)

                report = (
                    '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                    % (datetime.now(), self._step, run_values.results,
                       examples_per_sec, sec_per_batch))
                print(report)

        hooks = [
            # Requests a stop once the given step is reached.
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            # Monitors the loss tensor and stops training if it becomes NaN.
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        # Session configuration; this is what tf.Session takes as `config`.
        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)

        # checkpoint_dir is the directory variables are restored from.
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=hooks,
                config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
def main(argv=None):  # pylint: disable=unused-argument
    """Prepare the data set and a clean train directory, then train.

    Args:
        argv: Unused; accepted because tf.app.run() passes it in.
    """
    cifar10.maybe_download_and_extract()

    # If a previous run left a train directory behind, delete it so that
    # training starts again from scratch.
    train_dir = FLAGS.train_dir
    if tf.gfile.Exists(train_dir):
        tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)

    train()  # Defined in this file.
if __name__ == '__main__':
    # Parse the command line into the module-level FLAGS that train() and
    # main() read, then let TensorFlow's app runner invoke main().
    FLAGS = parser.parse_args()
    tf.app.run(main=main)
其他函数见下一篇…
本系列文章参考了diligent_321的文章“TensorFlow中cnn-cifar10样例代码详解”,感谢。