1.题记
因为课程设计要用到TensorFlow,所以这几天在看TensorFlow官方给的几个示例代码,前面几个例子比较简单,看完之后试着重新写了一下,从CIFAR10开始我觉得需要做点笔记了。官网的例子是分成好几个模块实现。
图片来自TensorFlow中文社区
这样的好处是从程序设计的角度来说比较优美,但是从我们看代码的角度来说,不利于阅读逻辑的连贯性。我个人看代码的习惯是先把主线上的执行逻辑理顺了,之后再考虑其他的执行分支和一些细节,其他像属于优化或者扩展,最大可用性等等的代码放到后面看。
所以我的笔记和官网的代码与教程的不同之处在于,我把Demo的主要逻辑捋直了,然后加上了比较详细地注释,感兴趣的同学可以顺着看下来,不用再跳来跳去,应该可以稍微加快阅读这个Demo的速度。
2. cifar10介绍
CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题。任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵盖了10个类别:飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。
谷歌这份Demo的目标是建立一个用于识别图像的相对较小的卷积神经网络。选择CIFAR-10是因为它的复杂程度足以用来检验TensorFlow中的大部分功能,并可将其扩展为更大的模型。与此同时由于模型较小所以训练速度很快,比较适合用来测试新的想法,检验新的技术。
3. 代码
1.导入库
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os.path
import time
import os
import re
import sys
import tarfile
import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
2.部分全局变量
# 定义全局变量
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
"""Directory where to write event logs """
"""and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
"""Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""")
# 基本模型参数
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
"""How often to log results to the console.""")
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
3.下载数据集
国内网络环境的原因,这段可能执行不成功,大家可以去官网下载好数据集,然后放置在/tmp/cifar10_data路径下,解压后的效果如下:
数据集存放位置
# 检测本地是否有数据集
def maybe_download_and_extract():
"""Download and extract the tarball from Alex's website."""
dest_directory = FLAGS.data_dir # /tmp/cifar10_data
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
# 从URL中获得文件名
filename = DATA_URL.split('/')[-1]
# 合并文件路径
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
# 定义下载过程中打印日志的回调函数
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
# 下载数据集
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath,
reporthook=_progress)
print()
# 获得文件信息
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
# 解压缩
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
maybe_download_and_extract()
4.
# 删除之前训练过程中产生的一些临时文件,并重新生成目录
if gfile.Exists(FLAGS.train_dir):
gfile.DeleteRecursively(FLAGS.train_dir)
gfile.MakeDirs(FLAGS.train_dir)
#定义记录训练步数的变量
global_step = tf.train.get_or_create_global_step()# tf.Variable(0, trainable=False)
5. 导入数据和标签
# 从 CIFAR-10 中导入数据和标签
IMAGE_SIZE = 24
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
def distorted_inputs(batch_size):
"""
Returns:
i