TensorFlow Models
GitHub:https://github.com/tensorflow/models
Document:https://github.com/jikexueyuanwiki/tensorflow-zh
CIFAR-10 数据集
Web:http://www.cs.toronto.edu/~kriz/cifar.html
目标:(建立一个用于识别图像的相对较小的卷积神经网络)对一组32x32RGB的图像进行分类
数据集:60000张32*32*3的彩色图片,其中50000张训练集,10000张测试集,涵盖10个类别:飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车
CIFAR-10 模型训练
GitHub:https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10
流程:首先读取图片,对图片预处理,进行数据增强,然后将图片存放到队列中打乱之后用于网络输入。其次构造模型,损失函数计算,学习率指数衰减,计算梯度,用梯度来求解最优值。最后开始训练。
1)导入库
# cifar10_train.py from __future__ import absolute_import from __future__ import division from __future__ import print_function from datetime import datetime import time import tensorflow as tf import cifar10 # cifar10.py from __future__ import absolute_import from __future__ import division from __future__ import print_function import re import tensorflow as tf import cifar10_input import os import sys import urllib import tarfile # cifar10_input.py from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf import tensorflow_datasets as tfds
2)使用FLAGS设置参数
# cifar10_train.py # 定义全局变量 FLAGS = tf.app.flags.FLAGS # 初始化
# 定义参数:tf.app.flags.DEFINE_xxx(参数1,参数2,参数3),xxx为参数类型
参数1为变量名,如train_dir,可通过FLAGS.train_dir取得该变量的值
参数2为默认值
参数3为说明内容,当不设置该变量的值时,通过FLAGS.train_dir取到的是其默认值(/tmp/cifar10_train),若要设置该变量的值,可通过运行时写参数--train_dir '路径'来设置(python cifar10_train.py --train_dir '路径') - 键入-h/--help,则打印说明内容
tf.app.flags.DEFINE_string('train_dir', './tmp/cifar10_train', """Directory where to write event logs """"""and checkpoint.""") tf.app.flags.DEFINE_integer('max_steps', 1000000, """Number of batches to run.""") tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""") # cifar10.py' # 基本模型参数 tf.app.flags.DEFINE_integer('batch_size', 128, """Number of images to process in a batch.""") tf.app.flags.DEFINE_boolean('use_fp16', True, """Train the model using fp16.""") tf.app.flags.DEFINE_string('data_dir', './tmp/cifar10_data', """Path to the CIFAR-10 data directory.""") tf.app.flags.DEFINE_integer('log_frequency', 10, """How often to log results to the console.""") DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
3)下载数据集
国内网络环境的原因,源码中下载数据集代码可能执行不成功,可以去官网下载好数据集,然后放置在FLAGS.data_dir路径下即可(需要设置FLAGS.data_dir路径)
下载:cifar-10-binary.tar.gz
下载至:*\tutorials\image\cifar10\tmp\cifar10_data(FLAGS.data_dir ='./tmp/cifar10_data')
# cifar10.py # 检测本地是否有数据集 def maybe_download_and_extract(): """Download and extract the tarball from Alex's website.""" dest_directory = FLAGS.data_dir # /tmp/cifar10_data # 判断文件夹是否存在,不存在则创建 if not os.path.exists(dest_directory): os.makedirs(dest_directory) # 从URL中获得文件名:DATA_URL定义为cifar10数据集下载地址,这里将URL最后一个斜杠后面的内容作为文件名 filename = DATA_URL.split('/')[-1] # 合并文件路径:将文件名与数据文件夹结合得到下载文件存放的路径 filepath = os.path.join(dest_directory, filename) # 判断文件是否存在,如果存在,表明数据集已经下载,就无需再下载,如果还没下载,则通过urllib.request.urlretrieve直接下载文件 if not os.path.exists(filepath): # 定义下载过程中打印日志的回调函数:回调函数用于显示下载进度,下载进度为当前下载量除以总下载量 def _progress(count, block_size, total_size): sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, float(count * block_size) / float(total_size) * 100.0