CIFAR-10 图像识别
CIFAR-10是一个包含日常生活物体的彩色图像数据集。以前用caffe跑过这个入门例子,这次的对象还是它。
下载
首先用脚本下载;
cifar10_download.py
# coding:utf-8
# Import the ready-made cifar10 module that lives in the current directory.
import cifar10
# Import TensorFlow.
import tensorflow as tf
# tf.app.flags.FLAGS is TensorFlow's global flag store; it is also used for
# command-line argument handling.
FLAGS = tf.app.flags.FLAGS
# Per the original author's note, the cifar10 module pre-defines
# tf.app.flags.FLAGS.data_dir as the CIFAR-10 data path.
# Override that path so the data lands in cifar10_data/.
FLAGS.data_dir = 'cifar10_data/'
# Downloads (and extracts) the dataset only if the data files are not there yet.
cifar10.maybe_download_and_extract()
关于FLAGS,这里写一个小脚本来认识它的用法:
import tensorflow as tf
# Each DEFINE_* call takes three arguments: flag name, default value and
# description. There is a separate DEFINE_* function for each value type.
tf.app.flags.DEFINE_string('test','hello',"description")
tf.app.flags.DEFINE_integer('int_test',111,"description2")
tf.app.flags.DEFINE_boolean('BOOLtest',True,"description3")
# Handle through which the flag values defined above are read.
FLAGS=tf.app.flags.FLAGS
def main(_):
    """Print the current value of each demo flag.

    The single positional argument is ignored; it exists only so the
    function can be invoked by ``tf.app.run()``.
    """
    print(FLAGS.test)
    print(FLAGS.int_test)
    print(FLAGS.BOOLtest)
# Script entry point: tf.app.run() is what invokes main().
if __name__ == '__main__':
    tf.app.run()
cifar10官方代码
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import time
import tensorflow as tf
import cifar10
FLAGS = tf.app.flags.FLAGS
# The previous example already showed how FLAGS is used.
# NOTE(review): main() recursively deletes FLAGS.train_dir before training,
# so this default must never point at a directory whose contents you want to
# keep. The path is also specific to one machine -- override it with
# --train_dir when running elsewhere.
tf.app.flags.DEFINE_string('train_dir', '/home/t64/Desktop/cifar10_data',
"""Directory where to write event logs """
"""and checkpoint.""")
# Training stops once this many steps have been run.
tf.app.flags.DEFINE_integer('max_steps', 10000,
"""Number of batches to run.""")
# Forwarded to tf.ConfigProto(log_device_placement=...) in train().
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""")
# The logger hook in train() prints one line every `log_frequency` steps.
tf.app.flags.DEFINE_integer('log_frequency', 10,
"""How often to log results to the console.""")
def train():
    """Train CIFAR-10 for a number of steps.

    Builds the input pipeline, model, loss and training op with the
    ``cifar10`` module inside a fresh graph, then repeatedly runs the
    training op in a ``tf.train.MonitoredTrainingSession`` until the session
    reports that it should stop. The session is given ``FLAGS.train_dir`` as
    its checkpoint directory and three hooks: stop at ``FLAGS.max_steps``,
    a NaN check on the loss, and a console logger.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # distorted_inputs() performs data augmentation (per the original
        # author's note: scaling, cropping, noise, ...); the augmentation
        # does not change the labels of the images.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model. The next few lines are the standard
        # "model -> loss -> train op" sequence.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                # Local step counter; starts at -1 so that the first
                # before_run() call brings it to 0.
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                # Only report every FLAGS.log_frequency steps.
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    # FLAGS.batch_size is not defined in this script;
                    # presumably it is registered by the cifar10 module.
                    examples_per_sec = (FLAGS.log_frequency *
                                        FLAGS.batch_size / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            # The hooks decide when to stop; just keep running the train op.
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
def main(argv=None):  # pylint: disable=unused-argument
    """Prepare the dataset and the training directory, then start training.

    WARNING: if ``FLAGS.train_dir`` already exists it is deleted recursively
    and recreated empty, so every run starts from scratch.

    Args:
        argv: Unused.
    """
    cifar10.maybe_download_and_extract()
    # Wipe any previous run's output before training again.
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()
# Script entry point: tf.app.run() is what invokes main().
if __name__ == '__main__':
    tf.app.run()
可视化工具tensorboard
在进行训练的时候,另开一个终端执行下面的指令(--logdir 要指向训练脚本里 train_dir 所设的目录,这里以 cifar10_train/ 为例):
tensorboard --logdir cifar10_train/
会得到一个地址:http://ubuntu:6006 (举例)
在浏览器打开就能看到训练的详细信息:
这一篇记得很水,六月过完一半才写了一篇,这个月都在忙着复习,说来惭愧,不是找理由,待修改。