from __future__ importabsolute_importfrom __future__ importdivisionfrom __future__ importprint_functionimportosimporttensorflow as tf
# Command-line flags for this script, registered through TensorFlow's flag
# module and read back through the module-level FLAGS object.
flags = tf.app.flags

# Each DEFINE_* call passes (name, default, help) positionally.  The keyword
# spellings `flag_name` / `default_value` / `docstring` are legacy names that
# newer, absl-backed tf.app.flags only accepts with a deprecation warning;
# positional arguments are accepted by both the old and the new signatures.
flags.DEFINE_integer('batch_size', 16, 'Batch 大小')  # help: batch size
flags.DEFINE_string('data_dir', './tfrecords', '数据存放位置')  # help: where the data is stored
# NOTE(review): the default path contains '&', which needs quoting if it is
# ever passed through a shell.
flags.DEFINE_string('model_dir', './cat&dog_model', '模型存放位置')  # help: where the model is stored
flags.DEFINE_integer('steps', 1000, '训练步数')  # help: number of training steps
flags.DEFINE_integer('classes', 2, '类别数量')  # help: number of classes

# Parsed flag values are read as attributes of this object (e.g. FLAGS.steps).
FLAGS = flags.FLAGS
MODES=[tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT]def input_fn(mode, batch_size=1):"""输入函数"""
defparser(serialized_example):"""如何处理数据集中的每一个数据"""
#解析单个example对象
features =tf.parse_single_example(
serialized_example,
features={'image/height': tf.FixedLenFeature([], tf.int64),'image/width