TF学习之DeepLabv3+代码阅读1(train)

本文详细解读了使用TensorFlow进行DeepLabv3+模型训练的代码,包括main()函数、_train_deeplab_model()损失计算、_tower_loss模块、_build_deeplab模型构建、_log_summaries日志总结以及_Average_gradients梯度平均等关键步骤。
摘要由CSDN通过智能技术生成

DeepLabv3+代码阅读之train.py

一、main()

def main(unused_argv):# main必须带参数,否则报错:'TypeError: main() takes no arguments (1 given)'; 
                      # main的参数名随意定义,无要求。
  tf.logging.set_verbosity(tf.logging.INFO) # 把日志设置在INFO级别
  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  graph = tf.Graph()# tf.Graph()表示实例化了一个类,一个用于tf计算和表示用的数据流图,就是呈现计算图(节点和线,操作和数据)
  					# 的“纸”
  with graph.as_default():# 表示将这个类实例(graph,新生成的图)作为整个tf运行环境的默认图
  						  # 同时tf里面也已经存好的一张默认图,通过tf.get_default_graph()来调用(显示这张默认纸)
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):# FLAGS.num_ps_tasks=0
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisble by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,# pascal_voc_seg
          split_name=FLAGS.train_split,# train
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],# FLAGS.train_crop_size = ['513', '513']
          													  # crop_size = [513, 513]
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows placing on CPU ops without GPU implementation.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )

      stop_hook = tf.train.Sto
DeepLabV3+是一种用于图像语义分割的深度学习模型,它结合了特征金字塔网络(FPN)、空间金字塔池化(SPP)以及改进的ASPP模块。以下是使用PythonTensorFlow库进行简化的DeepLabV3+模型训练的伪代码示例: ```python import tensorflow as tf from tensorflow.keras import Model from deeplabv3_plus import build_deeplabv3_plus # 定义超参数 learning_rate = 0.0001 batch_size = 8 epochs = 50 input_shape = (512, 512, 3) # 输入图片尺寸 # 加载预训练模型作为基础网络 base_model = build_deeplabv3_plus(input_shape) base_model.trainable = False # 预训练层不参与训练 # 构建新的分类头 output_layer = create_classifier_layers(base_model.output, num_classes=NUM_CLASSES) # 创建完整的模型 model = Model(inputs=base_model.input, outputs=output_layer) # 编译模型 optimizer = tf.keras.optimizers.Adam(learning_rate) loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy']) # 数据加载和预处理 train_dataset, val_dataset = load_data(batch_size=batch_size) data_augmentation = get_data_augmentation_pipeline() # 训练模型 for epoch in range(epochs): model.fit( train_dataset, epochs=1, validation_data=val_dataset, callbacks=[EarlyStopping(patience=5), ModelCheckpoint('deeplabv3+.h5', save_best_only=True)], steps_per_epoch=len(train_dataset), validation_steps=len(val_dataset), data_augmentation=data_augmentation ) # 使用最佳权重加载模型 best_weights_path = 'deeplabv3+.h5' model.load_weights(best_weights_path)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值