Defining InstanceNormalization, BatchRenormalization, GroupNormalization, and SwitchableNormalization in TensorFlow 2.0

This post defines InstanceNormalization (paper link), BatchRenormalization (paper link), GroupNormalization (paper link), and SwitchableNormalization (paper link) as custom Keras layers in TensorFlow 2.0. I have not tested this code thoroughly, so please forgive any mistakes.
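All four layers below subclass tf.keras.layers.Layer and expect NHWC feature maps, so the snippets only need the standard import:

import tensorflow as tf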

class InstanceNormalization(tf.keras.layers.Layer):
    def __init__(self,beta_initializer='zeros',gamma_initializer='ones',
                 beta_regularizer=None,gamma_regularizer=None,
                 beta_constraint=None,gamma_constraint=None,epsilon=1e-5,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        
    def build(self, input_shape):  
        assert len(input_shape)==4
        shape = (input_shape[-1],)
        self.gamma = self.add_weight(shape=shape,name='gamma',initializer=self.gamma_initializer,
                                regularizer=self.gamma_regularizer,constraint=self.gamma_constraint)
        self.beta = self.add_weight(shape=shape,name='beta',initializer=self.beta_initializer,
                                regularizer=self.beta_regularizer,constraint=self.beta_constraint)
        self.built = True
        
    def call(self,inputs):
        # Per-sample, per-channel statistics over the spatial dimensions
        # (keepdims avoids reshaping with a possibly unknown batch size).
        mean,variance = tf.nn.moments(inputs,axes=[1,2],keepdims=True)
        outputs = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return outputs*self.gamma + self.beta
    
    def get_config(self):
        config = {
            'epsilon': self.epsilon,
            'beta_initializer': tf.keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer': tf.keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer': tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': tf.keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': tf.keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint': tf.keras.constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
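As a quick sanity check, the layer can be dropped into a small functional model; the input size and filter count below are arbitrary placeholders:

inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(16, 3, padding='same')(inputs)
x = InstanceNormalization()(x)   # per-sample, per-channel normalization over H and W
x = tf.keras.layers.ReLU()(x)
model = tf.keras.Model(inputs, x)
model.summary()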
class BatchRenormalization(tf.keras.layers.Layer):
    def __init__(self,momentum=0.99,rmax=3,dmax=5,t_delta=1e-3,axis=-1,beta_initializer='zeros',
                 gamma_initializer='ones',beta_regularizer=None,gamma_regularizer=None,
                 beta_constraint=None,gamma_constraint=None,moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',epsilon=1e-3,
                 **kwargs):
        super(BatchRenormalization, self).__init__(**kwargs)
        self.momentum = momentum
        self.rmax = rmax
        self.dmax = dmax
        self.t_delta = t_delta
        self.axis = axis
        self.epsilon = epsilon
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        self.moving_mean_initializer = tf.keras.initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = tf.keras.initializers.get(moving_variance_initializer)
        
    def build(self, input_shape):  
        assert len(input_shape)==4
        shape = (input_shape[self.axis],)
        self.gamma = self.add_weight(shape=shape,name='gamma',initializer=self.gamma_initializer,
                                regularizer=self.gamma_regularizer,constraint=self.gamma_constraint)
        self.beta = self.add_weight(shape=shape,name='beta',initializer=self.beta_initializer,
                                regularizer=self.beta_regularizer,constraint=self.beta_constraint)
        self.moving_mean = self.add_weight(shape=shape,name='moving_mean',
                        initializer=self.moving_mean_initializer,trainable=False)
        self.moving_variance = self.add_weight(shape=shape,name='moving_variance',
                        initializer=self.moving_variance_initializer,trainable=False)
        self.t = self.add_weight(shape=(1,),name='iteration',
            initializer=tf.keras.initializers.Zeros(),trainable=False)
        self.r_value = self.add_weight(shape=(1,),name='r_value',
            initializer=tf.keras.initializers.Ones(),trainable=False)
        self.d_value = self.add_weight(shape=(1,),name='d_value',
            initializer=tf.keras.initializers.Zeros(),trainable=False)
        self.built = True
        
    def call(self,inputs,training=None):
        if not training:
            # Inference: normalize with the tracked moving statistics.
            return tf.keras.backend.batch_normalization(
                    inputs,
                    self.moving_mean,
                    self.moving_variance,
                    self.beta,
                    self.gamma,
                    axis=self.axis,
                    epsilon=self.epsilon)

        # Training: normalize with batch statistics, corrected by r and d
        # so the output stays close to what the moving statistics would give.
        mean,variance = tf.nn.moments(inputs,axes=[0,1,2])
        r = tf.sqrt(variance + self.epsilon) / tf.sqrt(self.moving_variance + self.epsilon)
        r = tf.stop_gradient(tf.clip_by_value(r,1/self.r_value,self.r_value))
        d = (mean - self.moving_mean) / tf.sqrt(self.moving_variance + self.epsilon)
        d = tf.stop_gradient(tf.clip_by_value(d,-self.d_value,self.d_value))
        outputs = (inputs - mean) * r / tf.sqrt(variance + self.epsilon) + d
        # Update the moving statistics and gradually relax the r/d clipping bounds.
        self.add_update([tf.keras.backend.moving_average_update(self.moving_mean,mean,self.momentum),
                      tf.keras.backend.moving_average_update(self.moving_variance,variance,self.momentum),
                      tf.keras.backend.update_add(self.t,[self.t_delta]),
                      tf.keras.backend.update(self.r_value, self.rmax/(1+(self.rmax-1)*
                                                                       tf.exp(-self.t)) ),
                      tf.keras.backend.update(self.d_value, self.dmax/(1+ ( self.dmax/1e-3 -1)
                                                                       * tf.exp(-(2*self.t))) )])
        return outputs*self.gamma + self.beta
    
    def get_config(self):
        config = {
            'momentum':self.momentum,
            'axis':self.axis,
            't_delta':self.t_delta,
            'rmax':self.rmax,
            'dmax':self.dmax,
            'epsilon': self.epsilon,
            'beta_initializer': tf.keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer': tf.keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer': tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': tf.keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': tf.keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint': tf.keras.constraints.serialize(self.gamma_constraint),
            'moving_mean_initializer':
                tf.keras.initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer':
                tf.keras.initializers.serialize(self.moving_variance_initializer),
        }
        base_config = super(BatchRenormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
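A minimal usage sketch; the architecture, input shape, and optimizer here are placeholders, not part of the original post. During training the r/d clipping bounds relax towards rmax and dmax as the iteration counter grows by t_delta per update, while inference falls back to the moving statistics:

inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, 3, padding='same')(inputs)
x = BatchRenormalization(momentum=0.99, rmax=3, dmax=5, t_delta=1e-3)(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')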
class GroupNormalization(tf.keras.layers.Layer):
    def __init__(self,group=32,beta_initializer='zeros',gamma_initializer='ones',
                 beta_regularizer=None,gamma_regularizer=None,
                 beta_constraint=None,gamma_constraint=None,epsilon=1e-5,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.group = group
        self.epsilon = epsilon
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        
    def build(self, input_shape):  
        assert len(input_shape) == 4
        assert input_shape[-1] >= self.group
        assert input_shape[-1] % self.group == 0
        
        shape = (input_shape[-1],)
        self.gamma = self.add_weight(shape=shape,name='gamma',initializer=self.gamma_initializer,
                                regularizer=self.gamma_regularizer,constraint=self.gamma_constraint)
        self.beta = self.add_weight(shape=shape,name='beta',initializer=self.beta_initializer,
                                regularizer=self.beta_regularizer,constraint=self.beta_constraint)
        self.built = True
        
    def call(self,inputs):
        shape = tf.keras.backend.int_shape(inputs)
        # Split the channel axis into (group, channels_per_group); -1 keeps the batch size dynamic.
        inputs = tf.reshape(inputs, shape=[-1, shape[1], shape[2],
                                           self.group, shape[3]//self.group])
        # Statistics over the spatial dimensions and the channels inside each group.
        mean,variance = tf.nn.moments(inputs,axes=[1,2,4],keepdims=True)
        outputs = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        outputs = tf.reshape(outputs,shape=[-1, shape[1], shape[2], shape[3]])
        return outputs*self.gamma + self.beta
    
    def get_config(self):
        config = {
            'epsilon': self.epsilon,
            'beta_initializer': tf.keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer': tf.keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer': tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': tf.keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': tf.keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint': tf.keras.constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape        
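For example, with a channel count that divides evenly into the default 32 groups (the shapes are placeholders):

inputs = tf.keras.Input(shape=(64, 64, 64))   # channels must be a multiple of group
x = GroupNormalization(group=32)(inputs)
model = tf.keras.Model(inputs, x)
print(model.output_shape)                     # (None, 64, 64, 64)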
class SwitchableNormalization(tf.keras.layers.Layer):
    def __init__(self,momentum=0.99,beta_initializer='zeros',gamma_initializer='ones',
                 mean_weights_initializer='ones',variance_weights_initializer='ones',
                 beta_regularizer=None,gamma_regularizer=None, 
                 mean_weights_regularizer=None,variance_weights_regularizer=None,
                 moving_mean_initializer='zeros',moving_variance_initializer='ones',
                 beta_constraint=None,gamma_constraint=None,epsilon=1e-5,
                 mean_weights_constraint=None,variance_weights_constraint=None,
                 **kwargs):
        super(SwitchableNormalization, self).__init__(**kwargs)
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean_initializer = tf.keras.initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = tf.keras.initializers.get(moving_variance_initializer)
        self.beta_initializer = tf.keras.initializers.get(beta_initializer)
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.mean_weights_initializer = tf.keras.initializers.get(mean_weights_initializer)
        self.variance_weights_initializer = tf.keras.initializers.get(variance_weights_initializer)
        self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.mean_weights_regularizer = tf.keras.regularizers.get(mean_weights_regularizer)
        self.variance_weights_regularizer = tf.keras.regularizers.get(variance_weights_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        self.mean_weights_constraint = tf.keras.constraints.get(mean_weights_constraint)
        self.variance_weights_constraint = tf.keras.constraints.get(variance_weights_constraint)
        
    def build(self, input_shape):  
        assert len(input_shape)==4
        shape = (input_shape[-1],)
        self.gamma = self.add_weight(shape=shape,name='gamma',initializer=self.gamma_initializer,
                                regularizer=self.gamma_regularizer,constraint=self.gamma_constraint)
        self.beta = self.add_weight(shape=shape,name='beta',initializer=self.beta_initializer,
                                regularizer=self.beta_regularizer,constraint=self.beta_constraint)
        
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)
        
        self.mean_weights = self.add_weight(
            shape=(3,),
            name='mean_weights',
            initializer=self.mean_weights_initializer,
            regularizer=self.mean_weights_regularizer,constraint=self.mean_weights_constraint)
        self.variance_weights = self.add_weight(
            shape=(3,),
            name='variance_weights',
            initializer=self.variance_weights_initializer,
            regularizer=self.variance_weights_regularizer,
            constraint=self.variance_weights_constraint)
        
        self.built = True
        
    def call(self,inputs,training=None):
        # Batch statistics: moving averages at inference time, batch moments during training.
        if not training:
            batch_norm_mean,batch_norm_variance = self.moving_mean, self.moving_variance
        else:
            batch_norm_mean,batch_norm_variance = tf.nn.moments(inputs,axes=[0,1,2],
                                                                keepdims=True)
        layer_norm_mean,layer_norm_variance = tf.nn.moments(inputs,axes=[1,2,3],
                                                            keepdims=True)
        instance_norm_mean,instance_norm_variance = tf.nn.moments(inputs,axes=[1,2],
                                                                  keepdims=True)
        # Learned softmax weights decide how much of each statistic to use.
        weight1 = tf.keras.backend.softmax(self.mean_weights)
        weight2 = tf.keras.backend.softmax(self.variance_weights)
        mean = weight1[0] * batch_norm_mean + weight1[1] * layer_norm_mean + weight1[2] * instance_norm_mean
        variance = weight2[0] * batch_norm_variance + weight2[1] * layer_norm_variance + weight2[2] * instance_norm_variance
        outputs = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        if training:
            # Track the batch statistics for use at inference time.
            self.add_update([tf.keras.backend.moving_average_update(self.moving_mean,
                                                     tf.reshape(batch_norm_mean,(-1,)),
                                                     self.momentum),
                        tf.keras.backend.moving_average_update(self.moving_variance,
                                                     tf.reshape(batch_norm_variance,(-1,)),
                                                     self.momentum)])
        return outputs*self.gamma + self.beta
    
    def get_config(self):
        config = {
            'epsilon': self.epsilon,
            'beta_initializer': tf.keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer': tf.keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer': tf.keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': tf.keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': tf.keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint': tf.keras.constraints.serialize(self.gamma_constraint),
            'momentum': self.momentum,
            'moving_mean_initializer': tf.keras.initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': tf.keras.initializers.serialize(self.moving_variance_initializer),
            'mean_weights_initializer': tf.keras.initializers.serialize(self.mean_weights_initializer),
            'variance_weights_initializer': tf.keras.initializers.serialize(self.variance_weights_initializer),
            'mean_weights_regularizer': tf.keras.regularizers.serialize(self.mean_weights_regularizer),
            'variance_weights_regularizer': tf.keras.regularizers.serialize(self.variance_weights_regularizer),
            'mean_weights_constraint': tf.keras.constraints.serialize(self.mean_weights_constraint),
            'variance_weights_constraint': tf.keras.constraints.serialize(self.variance_weights_constraint),
        }
        base_config = super(SwitchableNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
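Because every layer implements get_config, a model that uses them can be saved and reloaded by passing the class through custom_objects. A small untested sketch (the file name and shapes are placeholders):

inputs = tf.keras.Input(shape=(32, 32, 16))
x = SwitchableNormalization()(inputs)
model = tf.keras.Model(inputs, x)
model.save('switchable_norm_model.h5')
restored = tf.keras.models.load_model(
    'switchable_norm_model.h5',
    custom_objects={'SwitchableNormalization': SwitchableNormalization})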