tensorflow版本的deeplabv3+源码解读1

本文主要解读了TensorFlow实现的Deeplabv3+模型的源码,从整体结构到train.py的各个关键函数,包括_main_函数、_train_deeplab_model、_tower_loss、_build_deeplab和_average_gradients等,旨在帮助读者理解模型的训练逻辑和细节。
摘要由CSDN通过智能技术生成

读源码太痛苦了,各种看不懂。因为刚接触语义分割用了deeplab这个模型,想好好地把源码看一下。读第一遍只能把API查一下,了解函数的作用。这是读的第二遍,把各模块的注释写一下。如果有人有更好地方法读懂源代码,求告知。

1.deeplabv3+整体结构

看一下deeplabv3+整个文件夹结构:
在这里插入图片描述我是从local_test_mobilenetv2.sh作为入口开始读的。

2.train.py

2.1 首先看main函数:

def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO) # 将tensorflow日志信息输出到屏幕

  tf.gfile.MakeDirs(FLAGS.train_logdir) # 创建一个目录,若目录存在则成功,无返回
  tf.logging.info('Training on %s set', FLAGS.train_split) # 打印日志信息,train_split默认为train

  graph = tf.Graph() # 实例化一个graph类
  with graph.as_default(): # 作为整个tensorflow运行环境默认图
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)): # 指定模型运行的设备,分布式训练.num_ps_tasks默认为0,参数服务器数量
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisble by number of clones (GPUs).') # num_clones默认为1,train_batch_size默认为8,若除不尽则报错
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones # //整数除法
      # dataset/data_generator.py中的Dataset类
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
         scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)
      # 调用train.py中的_train_deeplab_model函数见2.2。传入的参数为tf.data.Iterator类型的迭代器,类别数,忽略标签
      # 返回更新模型参数的张量和日志操作
      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows placing on CPU ops without GPU implementation.
      # allow_soft_placement为true时,自动分配cpu和gpu
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)
      # 调用model.py中的函数
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      # 若给出预训练模型
      if FLAGS.tf_initial_checkpoint:
        # 调用utils/train_utils.py中的get_model_init_fn,返回从checkpoint初始化的模型参数
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )
      # train_number_of_steps默认为30000,训练的迭代次数,stop_hook是在特定步数停止的钩子
      stop_hook = tf.train.StopAtStepHook(
          last_step=FLAGS.training_number_of_steps)
    
     # profile路径,默认NOne
      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值