tensorflow实现Local Context Normalization

tensorflow实现Local Context Normalization

参考代码:PyTorch implementation for Local Context Normalization: Revisiting Local Normalization
参考文章:Local Context Normalization: Revisiting Local Normalization

参考代码是 PyTorch 实现,且针对的是 2D 图像的 LCN;笔者将其改写为 TensorFlow 1.4 代码,并扩展到 3D 图像。

Code

import tensorflow as tf
import keras
from keras.layers.core import Layer
import math
import os
import numpy as np


class LocalContextNorm(Layer):  # 3D
    """Local Context Normalization for 5-D volumes, computed with integral images.

    Every voxel is normalized with the mean/variance of a ``window_size`` box
    spanning ``channels_per_group`` consecutive channels.  Box sums come from a
    3-D integral image (cumulative sums over D, H, W) combined with a dilated
    2x2x2 inclusion-exclusion kernel; the resulting statistics cover an
    ``img_size - window_size[0]`` cube and are REFLECT-padded back to the full
    volume.  A learnable per-channel ``weight``/``bias`` is applied last.

    Args:
        channels_per_group: number of consecutive channels sharing statistics;
            must divide the channel count.
        window_size: box size ``[D, H, W]``.  ``call`` uses ``window_size[0]``
            for the output-size bookkeeping, so the window is expected to be
            cubic.
        eps: added to the variance before the square root.
        img_size: spatial size of the cubic input volume (D == H == W).
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128, **kwargs):
        super(LocalContextNorm, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size

    def build(self, input_shape):
        """Validate the input and create the per-channel affine parameters.

        Raises:
            Exception: if the input is not 5-D ([B, D, H, W, C]) or the window
                is not smaller than ``img_size``.
            ValueError: if the channel count is not divisible by
                ``channels_per_group``.
        """
        if len(input_shape) != 5:  # channels-last: [B, D, H, W, C]
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')

        if self.img_size <= self.window_size[0]:
            raise Exception('Window size must be smaller than image size in LCN case!')

        self.num_features = input_shape[-1]
        if self.num_features is not None and self.num_features % self.channels_per_group != 0:
            raise ValueError('Number of channels (%d) must be divisible by channels_per_group (%d)'
                             % (self.num_features, self.channels_per_group))

        # add_weight (rather than a bare tf.Variable) registers the parameters in
        # the layer's trainable_weights, so Keras optimizers update them and
        # save/load_weights include them.
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)

        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` ([B, D, H, W, C]) and apply ``weight``/``bias``."""
        # Work channels-first [B, C, D, H, W] to match the channels_first Conv3D calls.
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])
        inputs_shape = inputs.shape.as_list()
        C = inputs_shape[1]
        # Spatial size comes from img_size: the volume is assumed cubic.
        D, H, W = self.img_size, self.img_size, self.img_size

        G = math.floor(C / self.channels_per_group)

        assert C % self.channels_per_group == 0

        def use_window(inputs):
            """Return (x - local_mean) / sqrt(local_var + eps), channels-first."""
            inputs_sq = inputs * inputs
            # 3-D integral images (cumulative sums over D, H, W) of x and x^2.
            integral_img = tf.cumsum(tf.cumsum(tf.cumsum(inputs, axis=2), axis=3), axis=4)
            integral_img_sq = tf.cumsum(tf.cumsum(tf.cumsum(inputs_sq, axis=2), axis=3), axis=4)

            d = [self.window_size[0], self.window_size[1], self.window_size[2]]
            # 2x2x2 inclusion-exclusion kernel; dilated by the window size it
            # turns an integral image into sums over window-sized boxes.
            kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]
            # All-ones kernel that adds up channels_per_group adjacent channels.
            c_kernel = np.ones((self.channels_per_group, 1, 1)).tolist()

            # Box sums of x, one channel at a time (single-channel Conv3D + concat).
            sums = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                                        kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                                        data_format='channels_first')(tf.expand_dims(integral_img[:, 0, :, :, :], dim=1)))
            for i in range(1, C):
                temp = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                                            kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                                            data_format='channels_first')(tf.expand_dims(integral_img[:, i, :, :, :], dim=1)))
                sums = tf.concat([sums, temp], axis=1)

            # Sum box sums within each channel group: fold (H, W) into one axis so
            # a [channels_per_group, 1, 1] Conv3D can stride over the channel axis.
            temp_shape = self.img_size - self.window_size[0]
            sums = tf.expand_dims(tf.reshape(sums, [-1, C, temp_shape, temp_shape*temp_shape]), axis=1)
            sums = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, C, temp_shape, temp_shape*temp_shape],
                                                        filters=1, kernel_size=[self.channels_per_group, 1, 1], padding='valid',
                                                        kernel_initializer=keras.initializers.Constant(c_kernel), strides=[self.channels_per_group, 1, 1],
                                                        data_format='channels_first')(sums))
            assert G == sums.shape.as_list()[2]
            sums = tf.squeeze(tf.reshape(sums, [-1, 1, G, temp_shape, temp_shape, temp_shape]), squeeze_dims=1)  # [B, G, ., ., .]

            # Same two steps for the sums of squares.
            squares = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                                           kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                                           data_format='channels_first')(tf.expand_dims(integral_img_sq[:, 0, :, :, :], dim=1)))
            for i in range(1, C):
                temp = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                                            kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                                            data_format='channels_first')(tf.expand_dims(integral_img_sq[:, i, :, :, :], dim=1)))
                squares = tf.concat([squares, temp], axis=1)

            temp_squares_shape = self.img_size - self.window_size[0]
            squares = tf.expand_dims(tf.reshape(squares, [-1, C, temp_squares_shape, temp_squares_shape*temp_squares_shape]), axis=1)
            squares = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, C, temp_squares_shape, temp_squares_shape*temp_squares_shape],
                                                           filters=1, kernel_size=[self.channels_per_group, 1, 1], padding='valid',
                                                           kernel_initializer=keras.initializers.Constant(c_kernel), strides=[self.channels_per_group, 1, 1],
                                                           data_format='channels_first')(squares))
            assert G == squares.shape.as_list()[2]
            squares = tf.squeeze(tf.reshape(squares, [-1, 1, G, temp_squares_shape, temp_squares_shape, temp_squares_shape]), squeeze_dims=1)  # [B, G, ., ., .]

            # Mean and variance over window_volume * channels_per_group elements.
            n = self.window_size[0] * self.window_size[1] * self.window_size[2] * self.channels_per_group
            means = sums / n
            var = 1.0 / n * (squares - sums * sums / n)
            d, h, w = temp_shape, temp_shape, temp_shape

            # REFLECT-pad the statistics back to the full volume.  tf.pad's REFLECT
            # mode requires each padding to be smaller than the padded dimension.
            pad3d = [int(math.floor((D - d) / 2)), int(math.ceil((D - d) / 2)), int(math.floor((H - h) / 2)),
                     int(math.ceil((H - h) / 2)), int(math.floor((W - w) / 2)), int(math.ceil((W - w) / 2))]
            padded_means = tf.pad(means, [[0, 0], [0, 0], [pad3d[0], pad3d[1]], [pad3d[2], pad3d[3]], [pad3d[4], pad3d[5]]], 'REFLECT')
            padded_vars = tf.pad(var, [[0, 0], [0, 0], [pad3d[0], pad3d[1]], [pad3d[2], pad3d[3]], [pad3d[4], pad3d[5]]], 'REFLECT') + self.eps

            # Normalize each channel group with its own statistics.
            temp = (inputs[:, :self.channels_per_group, :, :, :] -
                    tf.expand_dims(padded_means[:, 0, :, :, :], dim=1)) / tf.sqrt((tf.expand_dims(padded_vars[:, 0, :, :, :], dim=1)))
            for i in range(1, G):
                t_temp = (inputs[:, i * self.channels_per_group:i * self.channels_per_group + self.channels_per_group, :, :, :] -
                          tf.expand_dims(padded_means[:, i, :, :, :], dim=1)) / tf.sqrt((tf.expand_dims(padded_vars[:, i, :, :, :], dim=1)))
                temp = tf.concat([temp, t_temp], axis=1)

            inputs = temp
            return inputs

        # Assumes D == H == W == img_size and a cubic window.
        inputs = use_window(inputs)
        inputs = tf.transpose(inputs, [0, 2, 3, 4, 1])

        return inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """Normalization does not change the shape."""
        return input_shape

class GroupContextNorm(Layer):  # 3D
    """Group Normalization for 5-D volumes (LCN with a window covering the whole image).

    Channels are split into groups of ``channels_per_group`` consecutive
    channels.  Each group is normalized with the mean/variance taken over all
    of its channels and voxels.  A learnable per-channel ``weight``/``bias`` is
    applied last.

    Args:
        channels_per_group: channels per group; must divide the channel count.
        window_size: only used to check that the window is at least
            ``img_size`` (i.e. this really is the global/GN case).
        eps: added to the variance before the square root.
        img_size: spatial size, used as a fallback when the input's spatial
            dims are not known statically.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128, **kwargs):
        super(GroupContextNorm, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size

    def build(self, input_shape):
        """Validate the input and create the per-channel affine parameters.

        Raises:
            Exception: if the input is not 5-D ([B, D, H, W, C]) or the window
                is smaller than ``img_size``.
            ValueError: if the channel count is not divisible by
                ``channels_per_group``.
        """
        if len(input_shape) != 5:  # channels-last: [B, D, H, W, C]
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')

        if self.img_size > self.window_size[0]:
            raise Exception('Window size must be large than image size in GN case!')

        self.num_features = input_shape[-1]
        if self.num_features is not None and self.num_features % self.channels_per_group != 0:
            raise ValueError('Number of channels (%d) must be divisible by channels_per_group (%d)'
                             % (self.num_features, self.channels_per_group))

        # add_weight registers the parameters with Keras (trainable_weights,
        # save/load); a bare tf.Variable would be invisible to the optimizer.
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)

        self.built = True

    def call(self, inputs, **kwargs):
        """Group-normalize ``inputs`` ([B, D, H, W, C]) and apply ``weight``/``bias``."""
        # Channels-first [B, C, D, H, W]: consecutive channels of a group become
        # contiguous when the trailing dims are flattened below.
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])
        _, C, D, H, W = inputs.shape.as_list()
        # Use the real spatial dims; fall back to img_size only when unknown.
        D, H, W = [self.img_size if dim is None else dim for dim in (D, H, W)]
        G = C // self.channels_per_group

        assert C % self.channels_per_group == 0

        grouped = tf.reshape(inputs, [-1, G, self.channels_per_group * D * H * W])
        means, var = tf.nn.moments(grouped, [2], keep_dims=True)
        normalized = tf.reshape((grouped - means) / tf.sqrt(var + self.eps), [-1, C, D, H, W])

        outputs = tf.transpose(normalized, [0, 2, 3, 4, 1])
        return outputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """Normalization does not change the shape."""
        return input_shape

if __name__ == '__main__':
    # Smoke test: build the LCN graph on a small synthetic volume.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # expose only GPU 1 to TensorFlow

    # 2 x 7 x 7 x 7 x 4 volume whose voxel at (d, h, w) holds (d+1)*(h+1)*(w+1).
    matrix = tf.ones([2, 7, 7, 7, 4])
    for axis in (1, 2, 3):
        matrix = tf.cumsum(matrix, axis=axis)

    spatial_size = matrix.shape.as_list()[1]
    lcn_layer = LocalContextNorm(channels_per_group=2, window_size=[3, 3, 3], img_size=spatial_size)
    matrix_after_lcn = lcn_layer(matrix)

    # The graph is only built here.  To evaluate it:
    #   sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    #   tf.global_variables_initializer().run(session=sess)
    #   print(sess.run(matrix_after_lcn))

Some Notes

1、keras 报错 ‘_TfDeviceCaptureOp’ object has no attribute ‘type’,tf1.4对应keras2.0.8,重装keras=2.0.8后解决。
2、tf.cond无法完成建图,故上述代码写了两个class。
3、tf.nn.layer以及tflearn.layers中在tf1.4版本下的conv3d没有dilation_rate参数,故使用keras.layer.Conv3D。
4、参考代码中二维积分图的运算使用了三维卷积,因此三维积分图需要四维卷积;但没有三维以上的卷积 API,故采用逐通道 Conv3D + tf.concat 的策略。
5、三维积分图运算矩阵为 kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]],可自行推导,二维积分图运算矩阵参考上述参考代码。
6、tf.pad无法完成比input size还大的padding,故use_window的window_size最大为img_size一半; ‘REFLECT’ padding方式不会重复最外面的那层,例如[-1, -1, -1]其实是复制的[1, 1, 1]的值,而不是[0, 0, 0]的值。

########################################################################
新的实现:

def Gaussian_mol(k, sigma):
    """Return a normalized, integer-quantized 3-D Gaussian kernel of side ``2*k``.

    The Gaussian is sampled at offsets ``-k .. k-1`` along each axis, so the
    peak sits at index ``k``.  Index ``(0, 0, 0)`` is the sample farthest from
    the peak and therefore the minimum.  All values are divided by it, rounded
    to the nearest integer (so every weight is >= 1), and finally normalized
    to sum to 1.

    Args:
        k: half side length; the result has shape ``(2k, 2k, 2k)``.
        sigma: standard deviation of the Gaussian, in voxels.

    Returns:
        numpy array of shape ``(2k, 2k, 2k)`` whose entries sum to 1.

    Raises:
        ValueError: if ``k < 1`` or ``sigma <= 0``.
    """
    if k < 1:
        raise ValueError('k must be >= 1, got %r' % (k,))
    if sigma <= 0:
        raise ValueError('sigma must be > 0, got %r' % (sigma,))

    offsets = np.arange(2 * k) - k
    sq_dist = (offsets[:, None, None] ** 2
               + offsets[None, :, None] ** 2
               + offsets[None, None, :] ** 2)
    # Ratio to the scalar corner value A[0, 0, 0].  The previous A / A[0][0]
    # divided by a 1-D row that broadcast along the last axis, which cancelled
    # the depth term and made the kernel constant along depth.  The Gaussian's
    # normalizing constant cancels in the ratio, so it is not computed.
    A = np.exp((sq_dist[0, 0, 0] - sq_dist) / (2.0 * sigma ** 2))
    A = np.round(A)
    return A / A.sum()

LCN v1

class LocalContextNorm_v1(Layer):  # 3D
    """Gaussian-weighted Local Context Normalization, one channel per group.

    Each channel is filtered independently with a Gaussian-weighted
    ``window_size`` kernel (stride 1, 'valid' padding).  The kernel comes from
    ``Gaussian_mol`` and is scaled by 255.  The resulting local mean/variance
    are REFLECT-padded back to the input's spatial size, and the input is
    normalized with them.  A learnable per-channel ``weight``/``bias`` is
    applied last.

    Args:
        channels_per_group: must be 1.
        window_size: cubic, even-sized window, e.g. ``(32, 32, 32)``.
        eps: added to the variance before the square root.
        img_size: spatial size, used only when the input's spatial dims are
            not known statically.
        Gaussian_sigma: standard deviation passed to ``Gaussian_mol``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(32, 32, 32), eps=1e-5, img_size=128, Gaussian_sigma=10, **kwargs):
        super(LocalContextNorm_v1, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size
        self.sigma = Gaussian_sigma

    def build(self, input_shape):
        """Validate the configuration and create weights and Gaussian kernels.

        Raises:
            Exception: if the input is not 5-D or ``channels_per_group != 1``.
            ValueError: if ``window_size`` is not a cubic, even-sized window
                (``Gaussian_mol`` yields a kernel of side ``2 * (ws // 2)``).
        """
        if len(input_shape) != 5:  # channels-last: [B, D, H, W, C]
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')

        if self.channels_per_group != 1:
            raise Exception('Channels per group must be 1!')

        ws = list(self.window_size)
        if len(ws) != 3 or len(set(ws)) != 1 or ws[0] < 2 or ws[0] % 2 != 0:
            raise ValueError('window_size must be cubic and even-sized, e.g. (32, 32, 32); got %r'
                             % (self.window_size,))

        self.num_features = input_shape[-1]
        # add_weight registers the parameters with Keras (trainable_weights,
        # save/load); a bare tf.Variable would be invisible to the optimizer.
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)

        # Shape (ws, ws, ws).  The x255 scale is kept from the original
        # implementation, presumably to match statistics tuned on 0-255 images.
        self.Gaussian_weight = Gaussian_mol(self.window_size[0] // 2, self.sigma) * 255
        self.sq_weight = self.Gaussian_weight * self.Gaussian_weight

        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` ([B, D, H, W, C]) and apply ``weight``/``bias``."""
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])  # [B, C, D, H, W]
        _, C, D, H, W = inputs.shape.as_list()
        D, H, W = [self.img_size if dim is None else dim for dim in (D, H, W)]
        wd, wh, ww = self.window_size
        sq_inputs = inputs * inputs

        def filter_per_channel(x, kernel):
            # Fold channels into the batch so the single-filter Conv3D sees one
            # input channel whose kernel matches the (ws, ws, ws) Gaussian exactly.
            flat = tf.reshape(x, [-1, 1, D, H, W])
            out = keras.layers.Conv3D(filters=1, kernel_size=self.window_size, padding='valid',
                                      kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1],
                                      data_format='channels_first')(flat)
            out = tf.reshape(out, [-1, C, D - wd + 1, H - wh + 1, W - ww + 1])
            return tf.stop_gradient(out)

        sums = filter_per_channel(inputs, self.Gaussian_weight)
        sq_sums = filter_per_channel(sq_inputs, self.sq_weight)

        n = wd * wh * ww
        means = sums / n
        var = 1.0 / n * (sq_sums - sums * sums / n)

        # 'valid' filtering shrank each spatial dim by ws - 1; REFLECT-pad it
        # back.  tf.pad's REFLECT mode requires each pad to be smaller than the
        # padded dim, i.e. roughly ws <= 2/3 of the spatial size.
        pad_lo = (wd - 1) // 2
        pad_hi = wd - 1 - pad_lo
        paddings = [[0, 0], [0, 0], [pad_lo, pad_hi], [pad_lo, pad_hi], [pad_lo, pad_hi]]
        pad_mean = tf.pad(means, paddings, 'REFLECT')
        pad_var = tf.pad(var, paddings, 'REFLECT') + self.eps
        new_inputs = (inputs - pad_mean) / tf.sqrt(pad_var)

        new_inputs = tf.transpose(new_inputs, [0, 2, 3, 4, 1])

        return new_inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """Normalization does not change the shape."""
        return input_shape

LCN v2

class LocalContextNorm_v2(Layer):  # 3D
    """Block-wise Gaussian-weighted Local Context Normalization, one channel per group.

    The volume is split into non-overlapping ``window_size`` blocks.  Each
    block's Gaussian-weighted (x255) statistics are computed with a strided
    Conv3D per channel and broadcast back over the block with
    ``UpSampling3D``.  The input is also multiplied by the tiled Gaussian
    before normalization.  The result is min-max rescaled over the whole
    tensor, then a learnable per-channel ``weight``/``bias`` is applied.

    Spatial dims are expected to be multiples of the window size so the
    upsampled statistics match the input shape.

    Args:
        channels_per_group: must be 1.
        window_size: cubic, even-sized window, e.g. ``(16, 16, 16)``.
        eps: added to the variance before the square root.
        img_size: spatial size; used to tile the Gaussian and as a fallback
            when the input's spatial dims are not known statically.
        Gaussian_sigma: standard deviation passed to ``Gaussian_mol``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(16, 16, 16), eps=1e-5, img_size=128, Gaussian_sigma=30, **kwargs):
        super(LocalContextNorm_v2, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size
        self.sigma = Gaussian_sigma

    def build(self, input_shape):
        """Validate the configuration and create weights and Gaussian kernels.

        Raises:
            Exception: if the input is not 5-D or ``channels_per_group != 1``.
            ValueError: if ``window_size`` is not a cubic, even-sized window
                (``Gaussian_mol`` yields a kernel of side ``2 * (ws // 2)``).
        """
        if len(input_shape) != 5:  # channels-last: [B, D, H, W, C]
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')

        if self.channels_per_group != 1:
            raise Exception('Channels per group must be 1!')

        ws = list(self.window_size)
        if len(ws) != 3 or len(set(ws)) != 1 or ws[0] < 2 or ws[0] % 2 != 0:
            raise ValueError('window_size must be cubic and even-sized, e.g. (16, 16, 16); got %r'
                             % (self.window_size,))

        self.num_features = input_shape[-1]
        # add_weight registers the parameters with Keras (trainable_weights,
        # save/load); a bare tf.Variable would be invisible to the optimizer.
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)

        # Shape (ws, ws, ws); the x255 scale is kept from the original implementation.
        self.Gaussian_weight = Gaussian_mol(self.window_size[0] // 2, self.sigma) * 255
        self.sq_weight = self.Gaussian_weight * self.Gaussian_weight

        # Tile the Gaussian once per block along each axis (at least one copy),
        # then add batch/channel axes: [1, 1, D', H', W'].
        reps = [max(1, self.img_size // w) for w in self.window_size]
        self.stack_weight = np.tile(self.Gaussian_weight, reps)[np.newaxis, np.newaxis]

        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` ([B, D, H, W, C]) and apply ``weight``/``bias``."""
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])  # [B, C, D, H, W]
        _, C, D, H, W = inputs.shape.as_list()
        D, H, W = [self.img_size if dim is None else dim for dim in (D, H, W)]
        wd, wh, ww = self.window_size
        sq_inputs = inputs * inputs

        def block_filter(x, kernel):
            # Fold channels into the batch so the single-filter Conv3D sees one
            # input channel; stride == window gives one value per block.
            flat = tf.reshape(x, [-1, 1, D, H, W])
            out = keras.layers.Conv3D(filters=1, kernel_size=self.window_size, padding='valid',
                                      kernel_initializer=keras.initializers.Constant(kernel),
                                      strides=[wd, wh, ww], data_format='channels_first')(flat)
            out = tf.reshape(out, [-1, C, D // wd, H // wh, W // ww])
            return tf.stop_gradient(out)

        sums = block_filter(inputs, self.Gaussian_weight)
        sq_sums = block_filter(sq_inputs, self.sq_weight)

        n = wd * wh * ww
        means = sums / n
        var = 1.0 / n * (sq_sums - sums * sums / n) + self.eps

        # Broadcast each block's statistics back over its voxels.
        pad_mean = tf.stop_gradient(keras.layers.UpSampling3D(size=(wd, wh, ww), data_format='channels_first')(means))
        pad_var = tf.stop_gradient(keras.layers.UpSampling3D(size=(wd, wh, ww), data_format='channels_first')(var))

        weight_inputs = inputs * tf.constant(self.stack_weight, dtype=tf.float32)

        new_inputs = (weight_inputs - pad_mean) / tf.sqrt(pad_var)
        new_inputs = tf.transpose(new_inputs, [0, 2, 3, 4, 1])

        # Min-max rescale over the whole tensor (all batch items together).
        # NOTE(review): yields NaN if the normalized tensor is constant.
        new_inputs = new_inputs - tf.reduce_min(new_inputs)
        new_inputs = new_inputs / tf.reduce_max(new_inputs)

        return new_inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """Normalization does not change the shape."""
        return input_shape

Some Notes

1、周报1220中是在原始 image(取值范围 0-255)上做的初步实现,并在其上调整了窗口大小和 sigma;但网络中输入的取值范围是 0-1。由于计算中涉及平方操作,两者不是线性关系,故在上述实现中将 gaussian weight 乘以 255,以与周报1220中的结果保持一致;经检查,实现正确。
2、由于要加入 gaussian weight,无法再使用积分图,故采用普通的 3D 卷积 API 实现,详见上述代码。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值