读源码太痛苦了,各种看不懂。因为刚接触语义分割用了deeplab这个模型,想好好地把源码看一下。读第一遍只能把API查一下,了解函数的作用。这是读的第二遍,把各模块的注释写一下。如果有人有更好地方法读懂源代码,求告知。
1.deeplabv3+整体结构
看一下deeplabv3+整个文件夹结构:
我是从local_test_mobilenetv2.sh作为入口开始读的。
2.train.py
2.1 首先看main函数:
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) # 将tensorflow日志信息输出到屏幕
tf.gfile.MakeDirs(FLAGS.train_logdir) # 创建一个目录,若目录存在则成功,无返回
tf.logging.info('Training on %s set', FLAGS.train_split) # 打印日志信息,train_split默认为train
graph = tf.Graph() # 实例化一个graph类
with graph.as_default(): # 作为整个tensorflow运行环境默认图
with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)): # 指定模型运行的设备,分布式训练.num_ps_tasks默认为0,参数服务器数量
assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
'Training batch size not divisble by number of clones (GPUs).') # num_clones默认为1,train_batch_size默认为8,若除不尽则报错
clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones # //整数除法
# dataset/data_generator.py中的Dataset类
dataset = data_generator.Dataset(
dataset_name=FLAGS.dataset,
split_name=FLAGS.train_split,
dataset_dir=FLAGS.dataset_dir,
batch_size=clone_batch_size,
crop_size=[int(sz) for sz in FLAGS.train_crop_size],
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size,
model_variant=FLAGS.model_variant,
num_readers=2,
is_training=True,
should_shuffle=True,
should_repeat=True)
# 调用train.py中的_train_deeplab_model函数见2.2。传入的参数为tf.data.Iterator类型的迭代器,类别数,忽略标签
# 返回更新模型参数的张量和日志操作
train_tensor, summary_op = _train_deeplab_model(
dataset.get_one_shot_iterator(), dataset.num_of_classes,
dataset.ignore_label)
# Soft placement allows placing on CPU ops without GPU implementation.
# allow_soft_placement为true时,自动分配cpu和gpu
session_config = tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False)
# 调用model.py中的函数
last_layers = model.get_extra_layer_scopes(
FLAGS.last_layers_contain_logits_only)
init_fn = None
# 若给出预训练模型
if FLAGS.tf_initial_checkpoint:
# 调用utils/train_utils.py中的get_model_init_fn,返回从checkpoint初始化的模型参数
init_fn = train_utils.get_model_init_fn(
FLAGS.train_logdir,
FLAGS.tf_initial_checkpoint,
FLAGS.initialize_last_layer,
last_layers,
ignore_missing_vars=True)
scaffold = tf.train.Scaffold(
init_fn=init_fn,
summary_op=summary_op,
)
# train_number_of_steps默认为30000,训练的迭代次数,stop_hook是在特定步数停止的钩子
stop_hook = tf.train.StopAtStepHook(
last_step=FLAGS.training_number_of_steps)
# profile路径,默认NOne
profile_dir = FLAGS.profile_logdir
if profile_dir is not None:
tf.gfile.MakeDirs