Resnet Cifar-10调试

一、下载和运行

https://github.com/tensorflow/models 页面即可下载
具体项目是 models/tutorials/image/cifar10_estimator/
$ curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
$ tar xzf cifar-10-python.tar.gz
$ python generate_cifar10_tfrecords.py --input_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/cifar-10-batches-py --output_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10/
python cifar10_main.py --data_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-10 \
                         --model_dir=/tmp/cifar10 \
                         --is_cpu_ps=True \
                         --force_gpu_compatible=True \
                         --num_gpus=1 \
                         --train_steps=10000
$ tensorboard --logdir=/tmp/cifar10

二、代码分析

1.cifar10_main.py
1.1 命令行参数处理
FLAGS = tf.flags.FLAGS

tf.flags.FLAGS定义在/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform文件夹下flags.py中。

FLAGS = _FlagValues()
class _FlagValues(object):
    def _parse_flags(self, args=None):
        result, unparsed = _global_parser.parse_known_args(args=args)
    def __getattr__(self, name):
        if not parsed:
              self._parse_flags()

python中如下代码的作用。

if __name__ = "__main__": #使用这种方式保证了,如果此文件被其它文件import的时候,不会执行main中的代码
    tf.app.run() #解析命令行参数,调用main函数 main(sys.argv)

在tf.app.run()中
     flags_passthrough = f._parse_flags(args=args)
1.2 训练和评估

1)训练和评估输入

train_input_fn = functools.partial(input_fn, subset='train',
                                  num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
                                 num_shards=FLAGS.num_gpus)

functools.partial的作用就是表明train_input_fn函数就是带了train和FLAGS.num_gpus参数的input_fn函数。
2)Session配置

sess_config = tf.ConfigProto()
sess_config.allow_soft_placement = True
sess_config.log_device_placement = FLAGS.log_device_placement
sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible

3)Estimator配置

config = tf.estimator.RunConfig()
config = config.replace(session_config=sess_config)
classifier = tf.estimator.Estimator(
    model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)

4)训练和评估

classifier.train(input_fn=train_input_fn,
                 steps=train_steps,
                 hooks=hooks)
eval_results = classifier.evaluate(
    input_fn=eval_input_fn,
    steps=eval_steps)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值