Tensorflow Object Detection API 源码分析之 utils/variables_helper.py

Tensorflow Object Detection API 源码分析之 utils/variables_helper.py

# model_lib.py calls get_variables_available_in_checkpoint from this module.
# That function returns the variables that are present in the checkpoint,
# which model_lib.py then restores from the checkpoint.
"""Helper functions for manipulating collections of variables during training.
"""
import logging
import re

import tensorflow as tf

# Short alias for the TF-Slim package; used below for
# slim.learning.multiply_gradients.
slim = tf.contrib.slim


# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
# tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Removes (or, inverted, selects) variables whose op name matches a regex.

  A variable "matches" when `re.match` succeeds for at least one pattern in
  filter_regex_list against `var.op.name`. Because `re.match` is used, a
  pattern has to match starting at the beginning of the name. Falsy entries
  of filter_regex_list (e.g. None or '') are ignored.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions.
    invert: (boolean).  If False (default), matching variables are dropped and
      the rest returned. If True, only the matching variables are returned.

  Returns:
    a list of the variables that survive the filter, in their original order.
  """
  # Discard empty entries up front; an empty pattern would match every name.
  active_patterns = [regex for regex in filter_regex_list if regex]

  def _matches_any_pattern(candidate):
    # The op name is looked up lazily, only when there is a pattern to test.
    return any(re.match(regex, candidate.op.name)
               for regex in active_patterns)

  # "not matched" is the keep decision in the non-inverted case; comparing it
  # against `invert` flips that decision when invert is True.
  return [candidate for candidate in variables
          if (not _matches_any_pattern(candidate)) != invert]


def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Scales the gradients of variables whose names match any given regex.

  The variables are selected with filter_variables(..., invert=True), i.e. a
  variable is selected when its op name matches at least one pattern. The
  actual scaling is delegated to slim.learning.multiply_gradients.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples), as returned
      by slim.learning.multiply_gradients.
  """
  all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
  selected_vars = filter_variables(all_vars, regex_list, invert=True)

  # Report every affected variable before building the multiplier map.
  for selected in selected_vars:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, selected.op.name)

  # Map each selected variable to the (float) factor for its gradient.
  multiplier_by_var = {}
  for selected in selected_vars:
    multiplier_by_var[selected] = float(multiplier)

  return slim.learning.multiply_gradients(grads_and_vars, multiplier_by_var)


def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Drops the gradient/variable pairs whose variable name matches a regex.

  Variables are selected with filter_variables(..., invert=True); every pair
  whose variable is among the selected ones is left out of the result, and
  each selected variable is logged as frozen.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex.
  """
  candidate_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
  frozen_vars = filter_variables(candidate_vars, regex_list, invert=True)

  for frozen in frozen_vars:
    logging.info('Freezing variable [%s]', frozen.op.name)

  # Keep only the pairs whose variable was not selected for freezing.
  remaining_pairs = []
  for grad_and_var in grads_and_vars:
    if grad_and_var[1] in frozen_vars:
      continue
    remaining_pairs.append(grad_and_var)
  return remaining_pairs


# Used by model_lib.py: returns the variables that are available in the checkpoint.
def get_variables_available_in_checkpoint(variables,
                                          checkpoint_path,
                                          include_global_step=True):
  """Returns the subset of variables available in the checkpoint.

  Inspects given checkpoint and returns the subset of variables that are
  available in it. A variable is kept only if its name is present in the
  checkpoint AND the checkpoint shape equals `variable.shape.as_list()`;
  otherwise a warning is logged and the variable is left out.

  When `variables` is a list, each variable is looked up under its
  `variable.op.name`. When it is a dict, the dict keys are used as the names
  to look up in the checkpoint.

  TODO(rathodv): force input and output to be a dictionary.

  Args:
    variables: a list or dictionary of variables to find in checkpoint.
    checkpoint_path: path to the checkpoint to restore variables from.
    include_global_step: whether to include `global_step` variable, if it
      exists. Default True.

  Returns:
    A list of variables if `variables` is a list, otherwise a dictionary
    mapping name to variable.
  Raises:
    ValueError: if `variables` is not a list or dict.
  """
  if isinstance(variables, list):
    variable_names_map = {variable.op.name: variable for variable in variables}
  elif isinstance(variables, dict):
    variable_names_map = variables
  else:
    raise ValueError('`variables` is expected to be a list or dict.')
  ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
  if not include_global_step:
    # Removing the entry makes the global step look "not in checkpoint" below.
    ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
  vars_in_ckpt = {}
  # Iterate in sorted-name order so the warnings come out deterministically.
  for variable_name, variable in sorted(variable_names_map.items()):
    if variable_name not in ckpt_vars_to_shape_map:
      logging.warning('Variable [%s] is not available in checkpoint',
                      variable_name)
      continue
    ckpt_shape = ckpt_vars_to_shape_map[variable_name]
    model_shape = variable.shape.as_list()
    if ckpt_shape == model_shape:
      vars_in_ckpt[variable_name] = variable
    else:
      logging.warning('Variable [%s] is available in checkpoint, but has an '
                      'incompatible shape with model variable. Checkpoint '
                      'shape: [%s], model variable shape: [%s]. This '
                      'variable will not be initialized from the checkpoint.',
                      variable_name, ckpt_shape, model_shape)
  if isinstance(variables, list):
    # dict.values() is only a view on Python 3; materialize it so callers
    # really get the list promised by the docstring.
    return list(vars_in_ckpt.values())
  return vars_in_ckpt
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值