1. main()
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
tf.gfile.MakeDirs(FLAGS.train_logdir)
tf.logging.info('Training on %s set', FLAGS.train_split)
graph = tf.Graph()
with graph.as_default():
with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
'Training batch size not divisble by number of clones (GPUs).')
clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
dataset = data_generator.Dataset(
dataset_name=FLAGS.dataset,
split_name=FLAGS.train_split,
dataset_dir=FLAGS.dataset_dir,
batch_size=clone_batch_size,
crop_size=[int(sz) for sz in FLAGS.train_crop_size],
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size,
model_variant=FLAGS.model_variant,
num_readers=2,
is_training=True,
should_shuffle=True,
should_repeat=True)
train_tensor, summary_op = _train_deeplab_model(
dataset.get_one_shot_iterator(), dataset.num_of_classes,
dataset.ignore_label)
session_config = tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False)
last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only)
init_fn = None
if FLAGS.tf_initial_checkpoint:
init_fn = train_utils.get_model_init_fn(
FLAGS.train_logdir,
FLAGS.tf_initial_checkpoint,
FLAGS.initialize_last_layer,
last_layers,
ignore_missing_vars=True)
scaffold = tf.train.Scaffold(
init_fn=init_fn,
summary_op=summary_op,
)
stop_hook = tf.train.Sto