TF Learning: DeepLabv3+ Code Reading 6 (train_utils)

DeepLabv3+ code reading: train_utils.py

1. get_model_learning_rate()

import six
import tensorflow as tf

from deeplab.core import preprocess_utils


def get_model_learning_rate(
    learning_policy,             # Learning rate policy for training.
    base_learning_rate,          # The base learning rate for model training.
    learning_rate_decay_step,    # Decay the base learning rate at a fixed step.
    learning_rate_decay_factor,  # The rate to decay the base learning rate.
    training_number_of_steps,    # Number of steps for training.
    learning_power,              # Power used for 'poly' learning policy.
    slow_start_step,             # Train with a small learning rate for the first few steps.
    slow_start_learning_rate,    # The learning rate employed during slow start.
    slow_start_burnin_type='none'):  # The burnin type for the slow start stage. Can be
                                     # 'none', meaning no burnin, or 'linear', meaning the
                                     # learning rate increases linearly from
                                     # slow_start_learning_rate and reaches
                                     # base_learning_rate after slow_start_step steps.
  """Gets model's learning rate.

  Computes the model's learning rate for different learning policies.
  Right now, only "step" and "poly" are supported.
  (1) The learning policy for "step" is computed as follows:
    current_learning_rate = base_learning_rate *
      learning_rate_decay_factor ^ (global_step / learning_rate_decay_step)
  See tf.train.exponential_decay for details.
  (2) The learning policy for "poly" is computed as follows:
    current_learning_rate = base_learning_rate *
      (1 - global_step / training_number_of_steps) ^ learning_power

  """
  global_step = tf.train.get_or_create_global_step()
  adjusted_global_step = global_step

  if slow_start_burnin_type != 'none':
    adjusted_global_step -= slow_start_step

  if learning_policy == 'step':
    learning_rate = tf.train.exponential_decay(
        base_learning_rate,
        adjusted_global_step,
        learning_rate_decay_step,
        learning_rate_decay_factor,
        staircase=True)
  elif learning_policy == 'poly':
    learning_rate = tf.train.polynomial_decay(
        base_learning_rate,
        adjusted_global_step,
        training_number_of_steps,
        end_learning_rate=0,
        power=learning_power)
  else:
    raise ValueError('Unknown learning policy.')

  adjusted_slow_start_learning_rate = slow_start_learning_rate
  if slow_start_burnin_type == 'linear':
    # Do linear burnin: increase linearly from slow_start_learning_rate and
    # reach base_learning_rate once global_step >= slow_start_step.
    adjusted_slow_start_learning_rate = (
        slow_start_learning_rate +
        (base_learning_rate - slow_start_learning_rate) *
        tf.to_float(global_step) / slow_start_step)
  elif slow_start_burnin_type != 'none':
    raise ValueError('Unknown burnin type.')

  # Employ small learning rate at the first few steps for warm start.
  return tf.where(global_step < slow_start_step,
                  adjusted_slow_start_learning_rate, learning_rate)
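
To make the two schedules concrete, here is a small NumPy sketch (not part of train_utils.py; the numbers are made-up assumptions) that reproduces the 'step' and 'poly' formulas from the docstring and prints the learning rate at a few global steps:

import numpy as np

base_lr = 1e-3
decay_step, decay_factor = 2000, 0.1   # 'step' policy parameters (assumed values)
train_steps, power = 30000, 0.9        # 'poly' policy parameters (assumed values)

def step_lr(global_step):
  # staircase=True in tf.train.exponential_decay floors the exponent.
  return base_lr * decay_factor ** np.floor(global_step / decay_step)

def poly_lr(global_step):
  return base_lr * (1.0 - global_step / train_steps) ** power

for g in [0, 1000, 10000, 29999]:
  print(g, step_lr(g), poly_lr(g))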

2. add_softmax_cross_entropy_loss_for_each_scale

Computes the softmax cross entropy loss for the output logits at each scale.
Args:
	scales_to_logits: A map from logits names for different scales to the corresponding logits, shape: [batch, logits_height, logits_width, num_classes].
	labels: Groundtruth labels, shape: [batch, image_height, image_width, 1].
	num_classes: Number of classes.
	ignore_label: Label value to be ignored when computing the loss.
	loss_weight: Weight applied to the loss (default 1.0).
	upsample_logits: Whether to upsample the logits to the label resolution (rather than downsample the labels).
	hard_example_mining_step: Training step over which the mining ratio is annealed down to top_k_percent_pixels (default 0, i.e. mine the top-k pixels from the start).
	top_k_percent_pixels: Fraction of pixels used for hard example mining (default 1.0, i.e. all pixels contribute to the loss).
	scope: The scope for the loss.
def add_softmax_cross_entropy_loss_for_each_scale(scales_to_logits,
                                                  labels,
                                                  num_classes,
                                                  ignore_label,
                                                  loss_weight=1.0,
                                                  upsample_logits=True,
                                                  hard_example_mining_step=0,
                                                  top_k_percent_pixels=1.0,
                                                  scope=None):

  if labels is None:
    raise ValueError('No label for softmax cross entropy loss.')
  # outputs_to_scales_to_logits = {k: {} for k in model_options.outputs_to_num_classes}
  # model_options.outputs_to_num_classes = {'semantic':21}
  # outputs_to_scales_to_logits = {'semantic': {'merged_logits': {}}}
  # scales_to_logits = outputs_to_scales_to_logits['semantic'] = {'merged_logits': {}}
  for scale, logits in six.iteritems(scales_to_logits):
    loss_scope = None
    if scope:
      loss_scope = '%s_%s' % (scope, scale)  # e.g. 'semantic_merged_logits'

    if upsample_logits:
      # Label is not downsampled; instead, upsample the logits to the label
      # size with bilinear interpolation.
      logits = tf.image.resize_bilinear(
          logits,
          preprocess_utils.resolve_shape(labels, 4)[1:3],
          align_corners=True)
      scaled_labels = labels
    else:
      # Label is downsampled to the same size as the logits, using nearest
      # neighbor interpolation so that label values stay valid class indices.
      scaled_labels = tf.image.resize_nearest_neighbor(
          labels,
          preprocess_utils.resolve_shape(logits, 4)[1:3],
          align_corners=True)

    # Flatten the labels into a 1-D vector of pixels.
    scaled_labels = tf.reshape(scaled_labels, shape=[-1])
    # tf.not_equal(x, y) returns the element-wise truth value of x != y, so the
    # mask selects pixels whose label is not ignore_label; the loss weight is
    # applied at the same time.
    not_ignore_mask = tf.to_float(tf.not_equal(scaled_labels,
                                               ignore_label)) * loss_weight
    # Convert the flattened labels to one-hot encoding.
    one_hot_labels = tf.one_hot(
        scaled_labels, num_classes, on_value=1.0, off_value=0.0)

    if top_k_percent_pixels == 1.0:
      # Compute the loss for all pixels.
      tf.losses.softmax_cross_entropy(
          one_hot_labels,
          tf.reshape(logits, shape=[-1, num_classes]),
          weights=not_ignore_mask,
          scope=loss_scope)
    else:
      logits = tf.reshape(logits, shape=[-1, num_classes])
      weights = not_ignore_mask
      with tf.name_scope(loss_scope, 'softmax_hard_example_mining',
                         [logits, one_hot_labels, weights]):
        one_hot_labels = tf.stop_gradient(
            one_hot_labels, name='labels_stop_gradient')
        pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_labels,
            logits=logits,
            name='pixel_losses')
        weighted_pixel_losses = tf.multiply(pixel_losses, weights)
        num_pixels = tf.to_float(tf.shape(logits)[0])
        # Compute the top_k_percent pixels based on current training step.
        if hard_example_mining_step == 0:
          # Directly focus on the top_k pixels.
          top_k_pixels = tf.to_int32(top_k_percent_pixels * num_pixels)
        else:
          # Gradually reduce the mining percent to top_k_percent_pixels.
          global_step = tf.to_float(tf.train.get_or_create_global_step())
          ratio = tf.minimum(1.0, global_step / hard_example_mining_step)
          top_k_pixels = tf.to_int32(
              (ratio * top_k_percent_pixels + (1.0 - ratio)) * num_pixels)
        top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses,
                                      k=top_k_pixels,
                                      sorted=True,
                                      name='top_k_percent_pixels')
        total_loss = tf.reduce_sum(top_k_losses)
        num_present = tf.reduce_sum(
            tf.to_float(tf.not_equal(top_k_losses, 0.0)))
        loss = _div_maybe_zero(total_loss, num_present)
        tf.losses.add_loss(loss)
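
The hard example mining branch is easier to see in isolation. Below is a hypothetical standalone sketch (TF 1.x style, with made-up shapes and an assumed top_k_percent_pixels of 0.5) of the same selection: compute per-pixel losses, zero out ignored pixels, and keep only the k largest losses:

import tensorflow as tf

num_classes, ignore_label = 21, 255
logits = tf.random_normal([6, num_classes])                    # 6 pixels, assumed
labels = tf.constant([0, 5, 255, 3, 3, 255], dtype=tf.int32)   # two ignored pixels

mask = tf.to_float(tf.not_equal(labels, ignore_label))
one_hot = tf.one_hot(labels, num_classes, on_value=1.0, off_value=0.0)
pixel_losses = tf.nn.softmax_cross_entropy_with_logits_v2(labels=one_hot, logits=logits)
weighted_pixel_losses = pixel_losses * mask                    # ignored pixels contribute zero loss

top_k_percent_pixels = 0.5                                     # keep the hardest 50% of pixels
num_pixels = tf.to_float(tf.shape(logits)[0])
k = tf.to_int32(top_k_percent_pixels * num_pixels)
top_k_losses, _ = tf.nn.top_k(weighted_pixel_losses, k=k, sorted=True)
num_present = tf.reduce_sum(tf.to_float(tf.not_equal(top_k_losses, 0.0)))
loss = tf.reduce_sum(top_k_losses) / tf.maximum(num_present, 1.0)  # mimics _div_maybe_zero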

3. get_model_gradient_multipliers

Gradient multipliers adjust the learning rate of individual model variables. For the segmentation task, the model is usually fine-tuned from a network pre-trained on an image classification task, and a larger learning rate (e.g. 10x) is typically used for the last layers.
Args:
	last_layers: Scopes of the last layers.
	last_layer_gradient_multiplier: The gradient multiplier for the last layers.
Returns:
	A map of gradient multipliers, {variable name: multiplier value}.
def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier):

  gradient_multipliers = {}

  for var in tf.model_variables():
    # Double the learning rate for biases.
    if 'biases' in var.op.name:
      gradient_multipliers[var.op.name] = 2.

    # Use larger learning rate for last layer variables.
    for layer in last_layers:
      if layer in var.op.name and 'biases' in var.op.name:
        gradient_multipliers[var.op.name] = 2 * last_layer_gradient_multiplier
        break
      elif layer in var.op.name:
        gradient_multipliers[var.op.name] = last_layer_gradient_multiplier
        break

  return gradient_multipliers
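
The returned {variable name: multiplier} map is consumed when the train op is built. A minimal sketch of one way to apply it, assuming total_loss is already defined and using made-up last-layer scope names and a plain MomentumOptimizer (the actual train.py may wire this up differently):

grad_mult = get_model_gradient_multipliers(
    last_layers=['decoder', 'logits'],        # hypothetical last-layer scopes
    last_layer_gradient_multiplier=10.0)

optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9)
grads_and_vars = optimizer.compute_gradients(total_loss)   # total_loss assumed to exist
# Scale each gradient by its multiplier (default 1.0 for variables not in the map).
scaled = [(None if g is None else g * grad_mult.get(v.op.name, 1.0), v)
          for g, v in grads_and_vars]
train_op = optimizer.apply_gradients(
    scaled, global_step=tf.train.get_or_create_global_step())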

4. get_model_init_fn

Initializes the model from a checkpoint.
Args:
	train_logdir: Directory where the training logs and checkpoints are stored.
	tf_initial_checkpoint: The checkpoint used for initialization.
	initialize_last_layer: Whether to initialize the last layer.
	last_layers: The last layers of the model.
	ignore_missing_vars: Whether to ignore variables that are missing from the checkpoint.
Returns:
	An initialization function that restores the variables from the checkpoint, or None if no initialization is needed.
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):

  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):  # A checkpoint already exists in train_logdir.
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  exclude_list = ['global_step']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  variables_to_restore = tf.contrib.framework.get_variables_to_restore(
      exclude=exclude_list)

  if variables_to_restore:
    init_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(unused_scaffold, sess):
      sess.run(init_op, init_feed_dict)
      sess.run([global_step])

    return restore_fn

  return None
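
The returned restore_fn has the (scaffold, session) signature that tf.train.Scaffold expects for its init_fn, so one plausible way to hook it into training (the paths and scope names below are made up) is:

init_fn = get_model_init_fn(
    train_logdir='/tmp/train_logs',                    # hypothetical directory
    tf_initial_checkpoint='/tmp/init/model.ckpt',      # hypothetical checkpoint
    initialize_last_layer=False,
    last_layers=['decoder', 'logits'],                 # hypothetical last-layer scopes
    ignore_missing_vars=True)

scaffold = tf.train.Scaffold(init_fn=init_fn)          # init_fn may be None; that is allowed
with tf.train.MonitoredTrainingSession(
    checkpoint_dir='/tmp/train_logs', scaffold=scaffold) as sess:
  while not sess.should_stop():
    sess.run(train_op)                                 # train_op assumed to be defined elsewhere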