I. Download and Run
Download the code from https://github.com/tensorflow/models.
The project is in models/tutorials/image/cifar10_estimator/.
$ curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
$ tar xzf cifar-10-python.tar.gz
$ python generate_cifar10_tfrecords.py --input_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/cifar-10-batches-py --output_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/
$ python cifar10_main.py --data_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10 \
    --model_dir=/tmp/cifar10 \
    --is_cpu_ps=True \
    --force_gpu_compatible=True \
    --num_gpus=1 \
    --train_steps=10000
$ tensorboard --logdir=/tmp/cifar10
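generate_cifar10_tfrecords.py reads the pickled batches from cifar-10-batches-py. In those batches, each row of the `data` array holds one 32x32 color image as 3072 values: the 1024 red values first, then 1024 green, then 1024 blue, each channel plane stored row by row. The stdlib-only sketch below shows how such a row maps to height x width x channel order; the helper `row_to_hwc` is hypothetical and does not reproduce the script itself.

```python
from typing import List, Sequence, Tuple

HEIGHT, WIDTH, CHANNELS = 32, 32, 3
PLANE = HEIGHT * WIDTH  # 1024 values per color channel


def row_to_hwc(row: Sequence[int]) -> List[List[Tuple[int, int, int]]]:
    """Hypothetical helper: turn one 3072-value CIFAR-10 row
    (R plane, then G plane, then B plane) into a HEIGHT x WIDTH grid
    of (r, g, b) pixels."""
    if len(row) != PLANE * CHANNELS:
        raise ValueError(f"expected {PLANE * CHANNELS} values, got {len(row)}")
    image = []
    for y in range(HEIGHT):
        line = []
        for x in range(WIDTH):
            offset = y * WIDTH + x  # position of this pixel inside each plane
            line.append((row[offset], row[PLANE + offset], row[2 * PLANE + offset]))
        image.append(line)
    return image
```

With NumPy the same conversion is usually written as `row.reshape(3, 32, 32).transpose(1, 2, 0)`: reshape to channel x height x width, then move the channel axis last.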
II. Code Analysis
1. cifar10_main.py
1.1 Command-line flag parsing
FLAGS = tf.flags.FLAGS
tf.flags.FLAGS is defined in flags.py under /usr/local/lib/python2.7/dist-packages/tensorflow/python/platform (the exact path depends on the Python version and installation). Abridged:
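class _FlagValues(object):
  ...
  def _parse_flags(self, args=None):
    # argparse's parse_known_args returns (namespace, unrecognized_args)
    result, unparsed = _global_parser.parse_known_args(args=args)
    ...
  def __getattr__(self, name):
    parsed = self.__dict__['__parsed']
    if not parsed:  # flags are parsed lazily, on the first attribute access
      self._parse_flags()
    ...

FLAGS = _FlagValues()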
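The key idea in this code is lazy parsing: nothing is parsed until the first attribute access on FLAGS, and parse_known_args lets unrecognized arguments pass through instead of raising an error. Below is a minimal stdlib sketch of the same pattern. The class `LazyFlags` and its `define_integer` method are hypothetical stand-ins, not TensorFlow's API, and the argument list is passed in explicitly (TensorFlow reads sys.argv) so the behavior is easy to check.

```python
import argparse
from typing import Any, List, Optional


class LazyFlags:
    """Hypothetical stand-in for TensorFlow 1.x's _FlagValues."""

    def __init__(self, argv: Optional[List[str]] = None) -> None:
        # Write straight into __dict__ so normal attribute lookup finds these
        # names and __getattr__ is never triggered for them.
        self.__dict__["_parser"] = argparse.ArgumentParser(add_help=False)
        self.__dict__["_argv"] = argv
        self.__dict__["_values"] = {}
        self.__dict__["_parsed"] = False
        self.__dict__["unparsed"] = []

    def define_integer(self, name: str, default: int, help_text: str = "") -> None:
        self._parser.add_argument("--" + name, type=int, default=default, help=help_text)

    def _parse_flags(self, args: Optional[List[str]] = None) -> List[str]:
        # parse_known_args returns (namespace, leftover_args) instead of
        # failing on options it does not recognize.
        result, unparsed = self._parser.parse_known_args(args=args)
        self._values.update(vars(result))
        self.__dict__["_parsed"] = True
        self.__dict__["unparsed"] = unparsed
        return unparsed

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails, i.e. for flag names.
        if not self._parsed:
            self._parse_flags(self._argv)
        try:
            return self._values[name]
        except KeyError:
            raise AttributeError(name) from None
```

Defining flags is cheap and order-independent because parsing is deferred until a value is actually read.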
What the following Python idiom does:
if __name__ == "__main__":  # true only when the file is run directly, so importing it from another file does not run main
  tf.app.run()  # parse the command-line flags, then call main(argv)
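Inside tf.app.run(), where f is flags.FLAGS, the known flags are parsed first and the remaining arguments are passed on to main:
flags_passthrough = f._parse_flags(args=args)
_sys.exit(main(_sys.argv[:1] + flags_passthrough))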
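The control flow of tf.app.run() can be imitated with argparse alone: parse the known flags, then call main with the program name followed by whatever the parser did not recognize. In this sketch the names `run` and `parser` are hypothetical, and unlike tf.app.run it returns main's result instead of handing it to sys.exit, so that it can be tested.

```python
import argparse
import sys
from typing import Callable, List, Optional

# Hypothetical parser holding one known flag.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--num_gpus", type=int, default=1)


def run(main: Callable[[List[str]], int], argv: Optional[List[str]] = None) -> int:
    """Hypothetical imitation of tf.app.run's argument handling."""
    argv = list(sys.argv) if argv is None else argv
    # Parse only the flags the parser knows; keep everything else for main.
    _, flags_passthrough = parser.parse_known_args(args=argv[1:])
    # main receives the program name followed by the unparsed arguments.
    return main(argv[:1] + flags_passthrough)
```

In a script this would sit under the `if __name__ == "__main__":` guard, for example as `sys.exit(run(main))`.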
1.2 Training and evaluation
1) Training and evaluation input functions
train_input_fn = functools.partial(input_fn, subset='train',
                                   num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
                                  num_shards=FLAGS.num_gpus)
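functools.partial returns a new callable with some arguments pre-filled: train_input_fn behaves like input_fn called with subset='train' and num_shards=FLAGS.num_gpus, and eval_input_fn likewise with subset='eval'. Each can therefore be handed to the Estimator as a ready-made input function.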
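A small self-contained illustration of the same binding. The `input_fn` here is a hypothetical stand-in that only reports its arguments, and `num_gpus` stands in for FLAGS.num_gpus; neither is the project's real code.

```python
import functools


def input_fn(subset, num_shards):
    """Hypothetical stand-in: report which subset and shard count were requested."""
    return {"subset": subset, "num_shards": num_shards}


num_gpus = 2  # stands in for FLAGS.num_gpus

train_input_fn = functools.partial(input_fn, subset="train", num_shards=num_gpus)
eval_input_fn = functools.partial(input_fn, subset="eval", num_shards=num_gpus)

# The partial objects are called with no arguments; the bound keywords are filled in.
print(train_input_fn())  # {'subset': 'train', 'num_shards': 2}
```

A lambda such as `lambda: input_fn(subset="train", num_shards=num_gpus)` would also work, but it reads num_gpus when called rather than when created, and partial objects expose `.func` and `.keywords`, which makes the binding easy to inspect.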
2) Session configuration
sess_config = tf.ConfigProto()
sess_config.allow_soft_placement = True
sess_config.log_device_placement = FLAGS.log_device_placement
sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
3) Estimator configuration
config = tf.estimator.RunConfig()
config = config.replace(session_config=sess_config)
classifier = tf.estimator.Estimator(
    model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
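tf.estimator.RunConfig is immutable: replace() does not modify config in place but returns a new RunConfig with the given fields changed, which is why its result is assigned back to config. The same pattern can be shown with a frozen dataclass and dataclasses.replace; `ToyRunConfig` below is a hypothetical analogy, not TensorFlow code.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class ToyRunConfig:
    """Hypothetical immutable config, analogous to tf.estimator.RunConfig."""
    model_dir: Optional[str] = None
    session_config: Optional[dict] = None

    def replace(self, **changes) -> "ToyRunConfig":
        # Return a modified copy; the original instance is left untouched.
        return dataclasses.replace(self, **changes)


config = ToyRunConfig()
new_config = config.replace(session_config={"allow_soft_placement": True})
```

Calling replace() without assigning the result would silently keep the old settings, since the original object never changes.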
4) Training and evaluation
classifier.train(input_fn=train_input_fn,
                 steps=train_steps,
                 hooks=hooks)
eval_results = classifier.evaluate(
    input_fn=eval_input_fn,
    steps=eval_steps)