import tensorflow as tf
from tensorflow.python.training.moving_averages import assign_moving_average

# _variable_on_cpu is assumed to be defined elsewhere in the codebase: a helper
# that creates a variable on /cpu:0 with the given initializer, where
# train=False marks it as non-trainable.

class batch_renorm():
    def __init__(self, n_out, renorm_momentum=0.97, RMAX=1, DMAX=0, epsilon=1e-3):
        self.n_out = n_out
        # Moving statistics, updated during training and frozen at inference.
        self.moving_mean = _variable_on_cpu('moving_mean', [self.n_out],
                                            initializer=tf.zeros_initializer,
                                            train=False)
        self.moving_variance = _variable_on_cpu('moving_variance', [self.n_out],
                                                initializer=tf.ones_initializer,
                                                train=False)
        self.epsilon = epsilon
        self.RMAX = RMAX  # r is clipped to [1/RMAX, RMAX]
        self.DMAX = DMAX  # d is clipped to [-DMAX, DMAX]
        self.renorm_momentum = renorm_momentum
    def __call__(self, inputs, train=True):
        # beta/gamma are created on every call, so each layer should wrap its
        # batch_renorm instance in its own variable scope and call it once.
        beta = tf.Variable(tf.constant(0.0, shape=[self.n_out]),
                           name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[self.n_out]),
                            name='gamma', trainable=True)
        def _batch_norm_training():
            # Per-batch statistics over the N, H, W axes (NHWC activations).
            batch_mean, batch_variance = tf.nn.moments(inputs, [0, 1, 2], name='moments')
            # new_mean / new_variance are presumably batch_mean / batch_variance
            moving_inv = tf.rsqrt(self.moving_variance + self.epsilon)
            # Renorm corrections r and d; stop_gradient keeps gradients from
            # flowing through them, as required by the paper.
            r = tf.stop_gradient(
                tf.clip_by_value(tf.sqrt(batch_variance + self.epsilon) * moving_inv,
                                 1 / self.RMAX, self.RMAX),
                name='renorm_r')
            d = tf.stop_gradient(
                tf.clip_by_value((batch_mean - self.moving_mean) * moving_inv,
                                 -self.DMAX, self.DMAX),
                name='renorm_d')
            # Fold r and d into the scale/offset of batch_normalization:
            # y = gamma * (x_hat * r + d) + beta, so scale = gamma * r and
            # offset = gamma * d + beta.
            scale = r
            offset = d
            if gamma is not None:
                scale *= gamma
                offset *= gamma
            if beta is not None:
                offset += beta
            with tf.control_dependencies(
                    [assign_moving_average(self.moving_mean, batch_mean, self.renorm_momentum),
                     assign_moving_average(self.moving_variance, batch_variance, self.renorm_momentum)]):
                return tf.nn.batch_normalization(inputs, batch_mean, batch_variance,
                                                 offset, scale, self.epsilon)
        def _batch_norm_inference():
            # At inference time, normalize with the frozen moving statistics.
            return tf.nn.batch_normalization(
                inputs,
                mean=self.moving_mean,
                variance=self.moving_variance,
                offset=beta,
                scale=gamma,
                variance_epsilon=self.epsilon)
        train = tf.convert_to_tensor(train)
        output = tf.cond(train, _batch_norm_training, _batch_norm_inference)
        return output
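
For reference, the training branch above implements the correction from Ioffe's Batch Renormalization paper (arXiv:1702.03275). With μ_B, σ_B the batch statistics and μ, σ the moving statistics:

x̂ = (x − μ_B) / σ_B · r + d,  where  r = clip(σ_B / σ, 1/RMAX, RMAX)  and  d = clip((μ_B − μ) / σ, −DMAX, DMAX),

and the output is y = γ·x̂ + β. Both r and d are treated as constants via stop_gradient, which is why the code passes scale = γ·r and offset = γ·d + β to tf.nn.batch_normalization. With the defaults RMAX=1 and DMAX=0, r ≡ 1 and d ≡ 0, so the layer reduces to ordinary batch normalization; this matches the warm-up phase recommended in the paper before the clipping limits are gradually relaxed (e.g. toward rmax = 3, dmax = 5).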
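
A minimal usage sketch follows, assuming graph-mode TensorFlow 1.x and that _variable_on_cpu is defined as described above; the scope name, shapes, and feed values are illustrative only:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32, 32, 64])  # NHWC feature map, 64 channels
is_training = tf.placeholder(tf.bool, [])           # switches the tf.cond branch

with tf.variable_scope('conv1_bn'):                 # keeps moving_mean/variance unique per layer
    bn = batch_renorm(n_out=64, renorm_momentum=0.97, RMAX=3, DMAX=5)
    y = bn(x, train=is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.zeros((8, 32, 32, 64), np.float32),
                                 is_training: True})

Note that the moving-average updates are attached as control dependencies inside the training branch, so simply evaluating y with is_training=True also updates moving_mean and moving_variance.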