TensorFlow 构建 batch normalization 时变量 gamma 不被训练的问题

跑完GAN之后,重载模型,观察训练变量列表。

# Dump every variable registered as trainable, one per line.
tvars = tf.trainable_variables()
for var in tvars:
    print(var)


### output ###
<tf.Variable 'g_w1:0' shape=(100, 32768) dtype=float32_ref>
<tf.Variable 'g_b1:0' shape=(32768,) dtype=float32_ref>
<tf.Variable 'bn1_g/beta:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'g_w2:0' shape=(4, 4, 128, 128) dtype=float32_ref>
<tf.Variable 'g_b2:0' shape=(128,) dtype=float32_ref>
<tf.Variable 'bn2_g/beta:0' shape=(128,) dtype=float32_ref>
###

 

我就很奇怪:命名空间'bn1_g'下怎么也该有个gamma吧,只有beta像什么话?

于是去翻源码。

# network.py下
# `scale` is not passed, so batch_norm's default scale=False applies and no
# gamma variable is created under scope 'bn2_g'.
g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5, scope='bn2_g')


#其中contrib.layers.batch_norm()定义在
#\Lib\site-packages\tensorflow\contrib\layers\python\layers\layers.py文件中

找到这个layers.py文件

...
#在函数中找到这么一段
  layer = normalization_layers.BatchNormalization(
	  axis=axis,
	  momentum=decay,
	  epsilon=epsilon,
	  center=center,
	  scale=scale,
	  beta_initializer=beta_initializer,
	  gamma_initializer=gamma_initializer,
	  moving_mean_initializer=moving_mean_initializer,
	  moving_variance_initializer=moving_variance_initializer,
	  beta_regularizer=beta_regularizer,
	  gamma_regularizer=gamma_regularizer,
	  trainable=trainable,
	  renorm=renorm,
	  renorm_clipping=renorm_clipping,
	  renorm_momentum=renorm_decay,
	  adjustment=adjustment,
	  name=sc.name,
	  _scope=sc,
	  _reuse=reuse,
	  fused=fused)
  outputs = layer.apply(inputs, training=is_training)
  print("==========i am here========") #正常输出
  # Add variables to collections.
  _add_variable_to_collections(layer.moving_mean, variables_collections,
							   'moving_mean')
  _add_variable_to_collections(layer.moving_variance, variables_collections,
							   'moving_variance')
  if layer.beta is not None:
	_add_variable_to_collections(layer.beta, variables_collections, 'beta')
  if layer.gamma is not None:
	print("==========hello gama is added==============") #没有输出
	_add_variable_to_collections(layer.gamma, variables_collections,
								 'gamma')
  else:
	  print("==========sorry layer.gama is none========") #成功输出

可见gamma不在训练变量中的直接原因是layer.gamma为None。初步怀疑是调用normalization_layers.BatchNormalization()创建layer对象的时候出了问题。

#在文件开头找到
from tensorflow.python.layers import normalization as normalization_layers

#于是我们来到
\Lib\site-packages\tensorflow\python\layers\normalization.py

#找到该文件下BatchNormalization()类的初始化函数


# Excerpt (TF 1.x) of tensorflow/python/layers/normalization.py: the v1
# `tf.layers.BatchNormalization` class that contrib's batch_norm() instantiates.
# NOTE(review): as pasted, the body of __init__ (the `super` call) sits at the
# same indentation as `def`; it must be one level deeper to be valid Python.
@tf_export(v1=['layers.BatchNormalization'])
class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
    # The constructor does no work of its own: every argument, including
    # `center` (controls beta) and `scale` (controls gamma), is forwarded
    # unchanged to the Keras parent class. No weight is created here.
    # `scale` defaults to True in this class, but contrib's batch_norm() always
    # passes its own `scale` argument explicitly (and that one defaults to False).
    def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer=init_ops.zeros_initializer(),
               gamma_initializer=init_ops.ones_initializer(),
               moving_mean_initializer=init_ops.zeros_initializer(),
               moving_variance_initializer=init_ops.ones_initializer(),
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               fused=None,
               trainable=True,
               virtual_batch_size=None,
               adjustment=None,
               name=None,
               **kwargs):
    super(BatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        renorm=renorm,
        renorm_clipping=renorm_clipping,
        renorm_momentum=renorm_momentum,
        fused=fused,
        trainable=trainable,
        virtual_batch_size=virtual_batch_size,
        adjustment=adjustment,
        name=name,
        **kwargs)

看起来很长,核心其实就一句话。

在构造一个BatchNormalization类的对象时,
调用它的父类的构造函数
 super(BatchNormalization, self).__init__()

于是我们要寻找BatchNormalization的父类在哪......

Python这种层层转手的依赖追起来真是难受,很多引用关系找不到,编辑器也不能直接跳转。

#我们注意到,该类的继承来自keras_layers.BatchNormalization
class BatchNormalization(keras_layers.BatchNormalization, base.Layer):

#文件头又有
from tensorflow.python.keras import layers as keras_layers

#于是我们来到目录
\Lib\site-packages\tensorflow\python\keras\layers\

#又一番搜索,才找到
\Lib\site-packages\tensorflow\python\keras\layers\normalization.py


#终于看到你了
# Excerpt (TF 1.x) of tensorflow/python/keras/layers/normalization.py: the
# Keras batch-normalization layer that the v1 wrapper ultimately inherits from.
@tf_export('keras.layers.BatchNormalization', v1=[])
class BatchNormalizationV2(Layer):
    # Only the start of the constructor is shown: it forwards just `name`,
    # `trainable` and **kwargs to the base `Layer` constructor. The other
    # arguments (center, scale, initializers, ...) are stored on `self` later
    # in the constructor (elided here); no weights are created in __init__.
    def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               fused=None,
               trainable=True,
               virtual_batch_size=None,
               adjustment=None,
               name=None,
               **kwargs):
        super(BatchNormalizationV2, self).__init__(
            name=name, trainable=trainable, **kwargs)

结果这个BatchNormalizationV2()的构造函数也是调用父类的构造......

继续继续。

#文件头看到
from tensorflow.python.keras.engine.base_layer import Layer

#于是来到
\Lib\site-packages\tensorflow\python\keras\engine\base_layer.py


#找到了BatchNormalizationV2的父类Layer的定义

# Excerpt: only the header of the Keras base class `Layer` (TF 1.x
# tensorflow/python/keras/engine/base_layer.py); its body is elided.
@tf_export('keras.layers.Layer')
class Layer(checkpointable.CheckpointableBase):

看了一下Layer的构造函数,发现它跟gamma毫无关系:

  # Base Keras `Layer` constructor. It validates keyword arguments and sets up
  # bookkeeping state (name, trainable flag, empty weight/loss/update lists,
  # `built = False`). It creates no weights of its own; per the comment on
  # `self.built` below, weights are created when `build` runs on layer call.
  @checkpointable.no_automatic_dependency_tracking
  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
    # These properties should be set by the user via keyword arguments.
    # note that 'dtype', 'input_shape' and 'batch_input_shape'
    # are only applicable to input layers: do not pass these keywords
    # to non-input layers.
    allowed_kwargs = {
        'input_shape',
        'batch_input_shape',
        'batch_size',
        'weights',
        'activity_regularizer',
    }
    # Validate optional keyword arguments.
    for kwarg in kwargs:
      if kwarg not in allowed_kwargs:
        raise TypeError('Keyword argument not understood:', kwarg)

    # Mutable properties
    # Indicates whether the layer's weights are updated during training
    # and whether the layer's updates are run during training
    self.trainable = trainable
    # A stateful layer is a layer whose updates are run during inference too,
    # for instance stateful RNNs.
    self.stateful = False
    # Indicates whether `build` needs to be called upon layer call, to create
    # the layer's weights.
    self.built = False
    # Provides information about which inputs are compatible with the layer.
    self.input_spec = None
    self.supports_masking = False

    self._init_set_name(name)
    self._activity_regularizer = kwargs.pop('activity_regularizer', None)
    # Weight lists start empty; they are only filled once weights are added.
    if not hasattr(self, '_trainable_weights'):
      self._trainable_weights = []
    if not hasattr(self, '_non_trainable_weights'):
      self._non_trainable_weights = []
    self._updates = []
    # A list of zero-argument lambdas which return Tensors, used for variable
    # regularizers.
    self._callable_losses = []
    # A list of symbolic Tensors containing activity regularizers and losses
    # manually added through `add_loss` in graph-building mode.
    self._losses = []
    # A list of loss values containing activity regularizers and losses
    # manually added through `add_loss` during eager execution. It is cleared
    # after every batch.
    # Because we plan on eventually allowing a same model instance to be trained
    # in eager mode or graph mode alternatively, we need to keep track of
    # eager losses and symbolic losses via separate attributes.
    self._eager_losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []
    # TODO(psv): Remove this property.
    # A dictionary that maps metric names to metric result tensors. The results
    # are the running averages of metric values over an epoch.
    self._metrics_tensors = {}
    self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
    self._call_fn_args = function_utils.fn_args(self.call)
    self._compute_previous_mask = ('mask' in self._call_fn_args or
                                   hasattr(self, 'compute_mask'))
    self._call_convention = (base_layer_utils
                             .CallConvention.EXPLICIT_INPUTS_ARGUMENT)
    if not hasattr(self, '_layers'):
      self._layers = []  # Dependencies tracked via attribute assignment.

    # These lists will be filled via successive calls
    # to self._add_inbound_node().
    self._inbound_nodes = []
    self._outbound_nodes = []

    call_argspec = tf_inspect.getfullargspec(self.call)
    if 'training' in call_argspec.args:
      self._expects_training_arg = True
    else:
      self._expects_training_arg = False

    # Whether the `call` method can be used to build a TF graph without issues.
    self._call_is_graph_friendly = True

    # Manage input shape information if passed.
    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
      # In this case we will later create an input layer
      # to insert before the current layer
      if 'batch_input_shape' in kwargs:
        batch_input_shape = tuple(kwargs['batch_input_shape'])
      elif 'input_shape' in kwargs:
        if 'batch_size' in kwargs:
          batch_size = kwargs['batch_size']
        else:
          batch_size = None
        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
      self._batch_input_shape = batch_input_shape

    # Manage initial weight values if passed.
    if 'weights' in kwargs:
      self._initial_weights = kwargs['weights']
    else:
      self._initial_weights = None

 

于是我们再次回到......

 

#回到BatchNormalizationV2类
#\Lib\site-packages\tensorflow\python\keras\layers\normalization.py

"""
    #上面一堆无关
"""
def __init__(...):  #一堆参数
    super(BatchNormalizationV2, self).__init__(
        name=name, trainable=trainable, **kwargs)
    #上面这段已知与gama无关,就不管了

"""
    #中间一堆无关的赋值
"""
    #重头戏
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
	
"""
    #下面一堆无关
"""

把重头戏拆出来,其实就一行。

self.gamma_initializer = initializers.get(gamma_initializer)

值得注意的是,这个文件是基于keras框架的。

所以定义要去keras里面找。

#千辛万苦找到了
#\Lib\site-packages\tensorflow\python\keras\initializers.py


@tf_export('keras.initializers.get')
def get(identifier):
  """Resolve an initializer specification into an initializer.

  None -> None; a dict -> deserialize(dict); a string such as 'ones' ->
  deserialize({'class_name': <string>, 'config': {}}); any callable is
  returned unchanged; anything else raises ValueError. The result is
  therefore None only when `identifier` itself is None.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))

看这个定义,好像只有gamma_initializer本身为None时才会返回None啊。

 

#测试
    self.gamma_initializer = initializers.get(gamma_initializer)
    if self.gamma_initializer is None:
        print("after get , gama_initializer is None")
    else:
        print("gama_ini is " + str(self.gamma_initializer))

#输出
gama_ini is <tensorflow.python.ops.init_ops.Ones object at 0x000000001EEE0438>

但测试结果显示,gamma_initializer没什么问题......

 

我们回到最初的起点

#layers.py
#在它创建完layer对象后,立刻检查gamma

layer = normalization_layers.BatchNormalization(...)

if layer.gamma is None:
          print("==========sorry layer.gama is none========")

print("going to make layer.apply")
outputs = layer.apply(inputs, training=is_training)



#输出
'BatchNormalization' object has no attribute 'gamma'

#可见创建完layer之后,gamma还没有产生
#进一步测试可以知道,此时beta也没有产生
#所以一切根源都在这行layer.apply里面

 

为了研究layer.apply,我们继续寻找。

#layers.py
#创建完layer对象后马上运算了outputs
outputs = layer.apply(inputs, training=is_training)


#其中,apply的定义在
#\Lib\site-packages\tensorflow\python\keras\engine\base_layer.py
#Layer类

# Layer.apply (base_layer.py): contrib's batch_norm calls `layer.apply(...)`,
# which does nothing but forward to __call__.
def apply(self, inputs, *args, **kwargs):
    """Apply the layer on an input.

    This is an alias of `self.__call__`.

    Arguments:
      inputs: Input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.

    Returns:
      Output tensor(s).
    """
    return self.__call__(inputs, *args, **kwargs)


#再看self.__call__()

def __call__(self, inputs, *args, **kwargs):
    """Generic entry point shared by every layer (base_layer.py).

    Keep in mind that we are only in base_layer.py: this is a standard entry
    point and does not define how any specific layer computes `call()`.
    In a neural network `call()` usually means the forward pass; invoking it
    computes the layer's result, and its return value is the output.
    Coming from the BN layer we want the normalized output, but the BN
    algorithm itself is not defined in this low-level file. What this method
    does guarantee is that the layer is built (see `_maybe_build` below)
    before `call()` runs for the first time.
    """
    # First flatten the (possibly nested) inputs into a flat list.
    input_list = nest.flatten(inputs)

    if context.executing_eagerly():
      # Accept NumPy inputs by converting to Tensors when executing eagerly.
      if all(isinstance(x, (np.ndarray, float, int)) for x in input_list):
        inputs = nest.map_structure(ops.convert_to_tensor, inputs)
        input_list = nest.flatten(inputs)

    # True when every input is a symbolic (graph-mode) TF tensor.
    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
    executing_eagerly = context.executing_eagerly()

    previous_mask = None
    if build_graph and (not hasattr(self, '_compute_previous_mask') or
                        self._compute_previous_mask):
      previous_mask = base_layer_utils.collect_previous_mask(inputs)
      if not hasattr(self, '_call_fn_args'):
        self._call_fn_args = function_utils.fn_args(self.call)
      if ('mask' in self._call_fn_args and 'mask' not in kwargs and
          not generic_utils.is_all_none(previous_mask)):
        # The previous layer generated a mask, and mask was not explicitly pass
        # to __call__, hence we set previous_mask as the default value.
        kwargs['mask'] = previous_mask

    input_shapes = None

    with ops.name_scope(self._name_scope()):
"""
      #划重点
"""
      # Key point: the layer is built lazily here, on its first call.
      if not self.built:
        # Build layer if applicable (if the `build` method has been overridden).
"""
        #试探性的调用可能存在的我们自定义的build()
"""
        # Runs the subclass's build() if it overrides one; for batch norm that
        # is where gamma/beta and the moving statistics are created.
        self._maybe_build(inputs)

        # We must set self.built since user defined build functions are not
        # constrained to set self.built.

        # Safety net in case a custom build() forgot to set the flag itself.
        self.built = True  

      # Check input assumptions set after layer building, e.g. input shape.
      if build_graph:
        # Symbolic execution on symbolic tensors. We will attempt to build
        # the corresponding TF subgraph inside `backend.get_graph()`
        input_spec.assert_input_compatibility(
            self.input_spec, inputs, self.name)
        graph = backend.get_graph()
"""
        #重点
"""
        # Key point: the forward pass is delegated to the subclass's call().
        with graph.as_default():
          if not executing_eagerly:
            # In graph mode, failure to build the layer's graph
            # implies a user-side bug. We don't catch exceptions.
            outputs = self.call(inputs, *args, **kwargs)
"""
            #这个__call__入口通过调用子类中具体定义的call来计算实际的output
"""
            # This generic entry point computes the real output by calling the
            # call() defined in the concrete subclass.
          else:
            try:
              outputs = self.call(inputs, *args, **kwargs)
"""
              #同理。下面都是无关操作了
"""
              # Same as above; everything below is unrelated to gamma.
            except Exception:  # pylint: disable=broad-except
              # Any issue during graph-building means we will later run the
              # model in eager mode, whether the issue was related to
              # graph mode or not. This provides a nice debugging experience.
              self._call_is_graph_friendly = False
              # We will use static shape inference to return symbolic tensors
              # matching the specifications of the layer outputs.
              # Since we have set `self._call_is_graph_friendly = False`,
              # we will never attempt to run the underlying TF graph (which is
              # disconnected).
              # TODO(fchollet): consider py_func as an alternative, which
              # would enable us to run the underlying graph if needed.
              input_shapes = nest.map_structure(lambda x: x.shape, inputs)
              output_shapes = self.compute_output_shape(input_shapes)
              outputs = nest.map_structure(
                  lambda shape: backend.placeholder(shape, dtype=self.dtype),
                  output_shapes)

          if outputs is None:
            raise ValueError('A layer\'s `call` method should return a '
                             'Tensor or a list of Tensors, not None '
                             '(layer: ' + self.name + ').')
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, previous_mask)
          if base_layer_utils.have_all_keras_metadata(inputs):
            inputs, outputs = self._set_connectivity_metadata_(
                inputs, outputs, args, kwargs)
          if hasattr(self, '_set_inputs') and not self.inputs:
            # Subclassed network: explicitly set metadata normally set by
            # a call to self._set_inputs().
            # This is not relevant in eager execution.
            self._set_inputs(inputs, outputs)
      else:
        # Eager execution on data tensors.
        outputs = self.call(inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        return outputs

    if not context.executing_eagerly():
      # Optionally load weight values specified at layer instantiation.
      # TODO(fchollet): consider enabling this with eager execution too.
      if (hasattr(self, '_initial_weights') and
          self._initial_weights is not None):
        self.set_weights(self._initial_weights)
        del self._initial_weights
    return outputs

 

现在我们可以从表象上非常清晰地理解,为什么layers.py文件中刚创建完的layer是没有beta和gamma的。

因为创建layer时只执行了构造函数链,最终落到base_layer.py中底层类Layer()的构造函数,而它只做一些通用的初始化,并不创建任何权重。

我们写base_layer的时候显然不知道,以后创建的会是BN层还是Convolution层还是什么什么层。

怎么会知道需要创建一个gamma和一个beta呢?

 

所以与BN层密切相关的两个参数gamma和beta的创建,显然要在与BN层直接相关的类里面定义。

顺着这个思路,在上面的__call__()中,与用户自定义相关的函数有两个。

一个是build(),一个是call()。

先来看call()。

#我们再次回到
#\Lib\site-packages\tensorflow\python\keras\layers\normalization.py
#在这个Layer的子类下面找到
# Excerpt of BatchNormalizationV2 (keras/layers/normalization.py) showing only
# call(), the forward pass. call() reads self.gamma / self.beta but never
# creates them; when gamma is None the non-fused path passes scale=None on to
# nn.batch_normalization.
class BatchNormalizationV2(Layer):
"""
"""
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    in_eager_mode = context.executing_eagerly()

    # virtual_batch_size ("ghost batch" norm): statistics are computed per
    # sub-batch of the incoming batch rather than over the whole batch.
    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation

      # Drop dim 0 (the batch size) and put -1 in its place; the remaining
      # entries are the static per-example shape.
      original_shape = [-1] + inputs.shape.as_list()[1:] 
      # >> [-1, dim1, dim2, dim3, ...]

      # Add a leading virtual_batch_size dimension.
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
      # >> [v_b_size, -1, dim1, dim2, dim3, ...]

      # Will cause errors if virtual_batch_size does not divide the batch size
      # Reshape [b_size, dim1, dim2, ...] into [v_b_size, -1, dim1, dim2, ...].
      # The -1 dimension is inferred, so this only works when batch_size is a
      # multiple of v_b_size (batch_size = v_b_size * k for some integer k).
      inputs = array_ops.reshape(inputs, expanded_shape)

      # Helper that undoes the reshape above.
      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs
"""
    #如果fused==True,就要使用一种faster & fused implementation
    #产出output,然后return,结束了call()
"""
    # If fused is True, use the faster fused implementation, produce the
    # output and return from call() right away.
    if self.fused:
      # The fused computation is delegated to _fused_batch_norm.
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        # Undo the virtual-batch reshape.
        outputs = undo_virtual_batching(outputs)
      return outputs
"""
    #如果fused==False or None
    #就按照下面的传统方法运作
"""  
    # Otherwise (fused is False/None) fall through to the regular path below.
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and
          len(v.get_shape()) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    # gamma/beta may be None (scale/center disabled); _broadcast passes None
    # through unchanged.
    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          inputs, reduction_axes, keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      def _do_update(var, value):
        if in_eager_mode and not self.trainable:
          return

        return self._assign_moving_average(var, value, self.momentum)

      mean_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
          lambda: self.moving_mean)
      variance_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_variance, new_variance),
          lambda: self.moving_variance)
      if not context.executing_eagerly():
        self.add_update(mean_update, inputs=True)
        self.add_update(variance_update, inputs=True)

    else:
      #training_value is False
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs

我们知道了,keras->BatchNormalizationV2类的call()函数跟一个fused参数有关,可以进行分支选择。

但是很遗憾,call()只是读取self.gamma(为None时就不做缩放),并不负责创建gamma。

 

于是我们把目光投向build()

#依然在这个文件里,build()和call()都在这
#\Lib\site-packages\tensorflow\python\keras\layers\normalization.py
#在这个Layer的子类下面找到
# Excerpt of BatchNormalizationV2.build() (keras/layers/normalization.py).
# build() runs on the layer's first __call__ and is where its weights are
# created: gamma only if self.scale, beta only if self.center, plus the
# non-trainable moving_mean / moving_variance (and renorm variables if enabled).
class BatchNormalizationV2(Layer):

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input Tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0')
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused in (None, True):
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      if self._USE_V2_BEHAVIOR:
        if self.fused is None:
          self.fused = (ndims == 4)
        elif self.fused and ndims != 4:
          raise ValueError('Batch normalization layers with fused=True only '
                           'support 4D input tensors.')
      else:
        assert self.fused is not None
        self.fused = (ndims == 4 and self._fused_can_be_used())
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require reshaping
      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
      # particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1]:
        self._data_format = 'NCHW'
      elif self.axis == [3]:
        self._data_format = 'NHWC'
      else:
        raise ValueError('Unsupported axis, fused batch norm only supports '
                         'axis == [1] or axis == [3]')

    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      param_dtype = dtypes.float32
    else:
      param_dtype = self.dtype or dtypes.float32

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                         input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [axis_to_dim[i] if i in axis_to_dim
                     else 1 for i in range(ndims)]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1      # Account for added dimension

"""
#划重点
"""
    # Key point: whether gamma exists at all depends only on self.scale.
    if self.scale:
      # Only when `scale` is True is gamma created, as a trainable weight.
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
    else:
      # Otherwise no gamma variable is created at all: self.gamma is None.
      # In the fused path a constant 1.0 tensor stands in for it; the
      # non-fused call() simply passes scale=None. Because nothing named
      # 'gamma' is ever created, it cannot show up in tf.trainable_variables(),
      # and contrib's batch_norm sees `layer.gamma is None` and skips it.
      self.gamma = None
      if self.fused:
        self._gamma_const = array_ops.constant(
            1.0, dtype=param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = array_ops.constant(
            0.0, dtype=param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_variance.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the desired
        # device and ignore any devices that may be set by the custom getter.
        def _renorm_variable(name, shape):
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=param_dtype,
              initializer=init_ops.zeros_initializer(),
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN)
          return var

        with distribution_strategy_context.get_distribution_strategy(
        ).colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
        # renorm_stddev_weight. This allows us to (1) mix the average
        # stddev with the minibatch stddev early in training, and (2) compute
        # the unbiased average stddev by dividing renorm_stddev by the weight.
        with distribution_strategy_context.get_distribution_strategy(
        ).colocate_vars_with(self.moving_variance):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                       ())
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True

现在我们确定了gamma,beta是在build()中定义的。

gamma由self.scale参数控制,beta由self.center参数控制。

 

但我们还是想知道,为什么会出现gamma缺席、而beta照常训练的结果。

在class BatchNormalizationV2(Layer)的参数介绍中,这样写道

center: If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
scale: If True, multiply by `gamma`.
        If False, `gamma` is not used.
        When the next layer is linear (also e.g. `nn.relu`),
        this can be disabled since the scaling
        will be done by the next layer.

 看完说明恍然大悟:我使用的GAN模型中,generator每一层卷积后都用leaky ReLU激活。ReLU和leaky ReLU对任意γ>0都满足f(γx)=γ·f(x),所以这个缩放完全可以被下一层(线性的)卷积权重吸收,gamma是多余的——这正是说明中“可以关闭scale”的依据。

 

而我调用时并没有设置scale参数,tf.contrib.layers.batch_norm()的默认参数是

def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               ...)

这就是为什么center对应的beta生效了,而scale对应的gamma被忽略了。

 

如果我在调用batch_norm()的地方修改一下

# Before: scale is left at its default (False), so beta is the only trainable
# BN parameter created under scope 'bn1_g'.
g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='bn1_g')

#修改为

# After: scale=True makes the layer's build() also create a trainable gamma.
g1 = tf.contrib.layers.batch_norm(g1, scale=True, epsilon=1e-5, scope='bn1_g')

这样就能保证gamma也加入训练变量列表了。
