tensorflow实现Local Context Normalization
参考代码:PyTorch implementation for Local Context Normalization: Revisiting Local Normalization
参考文章:Local Context Normalization: Revisiting Local Normalization
参考代码是PyTorch实现,且只针对2D图像的LCN;笔者将其改写为TensorFlow 1.4版本,并扩展到3D图像。
Code
import tensorflow as tf
import keras
from keras.layers.core import Layer
import math
import os
import numpy as np
class LocalContextNorm(Layer):  # 3D
    """Local Context Normalization (LCN) for cubic 3D feature maps.

    Every position is normalized with the mean and variance of a sliding
    window that spans ``window_size`` voxels and ``channels_per_group``
    consecutive channels; a learnable per-channel scale (``weight``) and
    shift (``bias``) are applied afterwards.  Window sums are read from 3D
    integral images (triple cumulative sums) with a dilated 2x2x2
    convolution, then summed over each channel group.

    The layer takes channels-last input ``[B, D, H, W, C]``.  All three
    spatial sides are assumed to equal ``img_size`` and only
    ``window_size[0]`` is used for the size bookkeeping, so the window is
    assumed to be cubic as well.

    Args:
        channels_per_group: number of consecutive channels sharing statistics;
            must divide the channel count.
        window_size: window extent ``[D, H, W]`` in voxels.
        eps: constant added to the variance before the square root.
        img_size: side length of the (cubic) input volume.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    # Inclusion-exclusion signs that turn the 8 corners of a box in a 3D
    # integral image into the sum over that box.
    _BOX_KERNEL = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]

    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128, **kwargs):
        super(LocalContextNorm, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size

    def build(self, input_shape):
        """Validate the configuration and create the per-channel scale/shift.

        Raises:
            ValueError: if the input is not 5-D, if the window is not smaller
                than the image, or if the window is too large for the
                reflect padding that restores the input size.
        """
        if len(input_shape) != 5:  # [B, D, H, W, C]
            raise ValueError('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')
        window = self.window_size[0]
        if self.img_size <= window:
            raise ValueError('Window size must be smaller than image size in LCN case!')
        # The statistics map has side (img_size - window) and is reflect-padded
        # by up to ceil(window / 2); REFLECT padding must stay below that side.
        if window - window // 2 > self.img_size - window - 1:
            raise ValueError(
                'Window size {} is too large for image size {}: reflect padding '
                'cannot restore the input size.'.format(window, self.img_size))
        self.num_features = input_shape[-1]
        # add_weight registers the variables with Keras so they are trained and
        # saved together with the model (a bare tf.Variable is not tracked).
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)
        self.built = True

    @staticmethod
    def _integral_image(volume):
        """Return the 3D integral image of a ``[B, C, D, H, W]`` tensor."""
        return tf.cumsum(tf.cumsum(tf.cumsum(volume, axis=2), axis=3), axis=4)

    def _grouped_window_sums(self, integral, num_channels, num_groups):
        """Sum an integral image over every window and channel group.

        Returns a ``[B, G, s, s, s]`` tensor with ``s = img_size -
        window_size[0]``; the result is wrapped in ``tf.stop_gradient``.
        """
        group_size = self.channels_per_group
        side = self.img_size - self.window_size[0]
        dilation = [self.window_size[0], self.window_size[1], self.window_size[2]]

        # There is no 4-D convolution, so the box filter is applied to each
        # channel separately and the results are concatenated.
        per_channel = []
        for c in range(num_channels):
            box_filter = keras.layers.Conv3D(
                filters=1, kernel_size=2, padding='valid',
                kernel_initializer=keras.initializers.Constant(self._BOX_KERNEL),
                strides=[1, 1, 1], dilation_rate=dilation, data_format='channels_first')
            channel = tf.expand_dims(integral[:, c, :, :, :], axis=1)
            per_channel.append(tf.stop_gradient(box_filter(channel)))
        window_sums = tf.concat(per_channel, axis=1)  # [B, C, s, s, s]

        # Treat the channel axis as a spatial axis and add up each group of
        # `group_size` channels with a strided all-ones convolution.
        stacked = tf.expand_dims(tf.reshape(window_sums, [-1, num_channels, side, side * side]), axis=1)
        group_filter = keras.layers.Conv3D(
            filters=1, kernel_size=[group_size, 1, 1], padding='valid',
            kernel_initializer=keras.initializers.Constant(np.ones((group_size, 1, 1)).tolist()),
            strides=[group_size, 1, 1], data_format='channels_first')
        grouped = tf.stop_gradient(group_filter(stacked))  # [B, 1, G, s, s*s]
        assert num_groups == grouped.shape.as_list()[2]  # internal invariant
        return tf.reshape(grouped, [-1, num_groups, side, side, side])

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` (``[B, D, H, W, C]``) and apply scale/shift.

        Raises:
            ValueError: if the channel count is not a multiple of
                ``channels_per_group``.
        """
        group_size = self.channels_per_group
        window = self.window_size[0]

        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])  # -> [B, C, D, H, W]
        num_channels = inputs.shape.as_list()[1]
        if num_channels % group_size != 0:
            raise ValueError('Number of channels ({}) must be divisible by channels_per_group ({}).'
                             .format(num_channels, group_size))
        num_groups = num_channels // group_size

        integral = self._integral_image(inputs)
        integral_sq = self._integral_image(inputs * inputs)
        sums = self._grouped_window_sums(integral, num_channels, num_groups)
        squares = self._grouped_window_sums(integral_sq, num_channels, num_groups)

        n = self.window_size[0] * self.window_size[1] * self.window_size[2] * group_size
        means = sums / n
        # E[x^2] - E[x]^2 can come out slightly negative from round-off in the
        # integral images; clamp so the square root below never yields NaN.
        variances = tf.maximum(1.0 / n * (squares - sums * sums / n), 0.0)

        # The statistics maps are `window` voxels shorter than the input on
        # every spatial axis; reflect-pad them back to the input size.
        pad_before = window // 2
        pad_after = window - pad_before
        paddings = [[0, 0], [0, 0]] + [[pad_before, pad_after]] * 3
        padded_means = tf.pad(means, paddings, 'REFLECT')
        padded_vars = tf.pad(variances, paddings, 'REFLECT') + self.eps

        # Normalize every channel group with its own statistics.
        normalized_groups = []
        for g in range(num_groups):
            channels = inputs[:, g * group_size:(g + 1) * group_size, :, :, :]
            group_mean = tf.expand_dims(padded_means[:, g, :, :, :], axis=1)
            group_var = tf.expand_dims(padded_vars[:, g, :, :, :], axis=1)
            normalized_groups.append((channels - group_mean) / tf.sqrt(group_var))
        normalized = tf.concat(normalized_groups, axis=1)

        normalized = tf.transpose(normalized, [0, 2, 3, 4, 1])  # back to channels-last
        return normalized * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """The layer preserves the input shape."""
        return input_shape
class GroupContextNorm(Layer):  # 3D
    """Group-normalization counterpart of the local context norm layer.

    Used when the window covers the whole volume: each group of
    ``channels_per_group`` consecutive channels is normalized with the mean
    and variance taken over all of its voxels, followed by a learnable
    per-channel scale (``weight``) and shift (``bias``).

    The layer takes channels-last input ``[B, D, H, W, C]`` whose three
    spatial sides are all assumed to equal ``img_size``.

    Args:
        channels_per_group: number of consecutive channels sharing statistics;
            must divide the channel count.
        window_size: window extent ``[D, H, W]``; only checked against
            ``img_size`` in ``build``.
        eps: constant added to the variance before the square root.
        img_size: side length of the (cubic) input volume.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128, **kwargs):
        super(GroupContextNorm, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size

    def build(self, input_shape):
        """Validate the configuration and create the per-channel scale/shift.

        Raises:
            ValueError: if the input is not 5-D or the image is larger than
                the window.
        """
        if len(input_shape) != 5:  # [B, D, H, W, C]
            raise ValueError('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')
        if self.img_size > self.window_size[0]:
            raise ValueError('Window size must be large than image size in GN case!')
        self.num_features = input_shape[-1]
        # add_weight registers the variables with Keras so they are trained and
        # saved together with the model (a bare tf.Variable is not tracked).
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)
        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` (``[B, D, H, W, C]``) per channel group.

        Raises:
            ValueError: if the channel count is not a multiple of
                ``channels_per_group``.
        """
        group_size = self.channels_per_group
        side = self.img_size

        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])  # -> [B, C, D, H, W]
        num_channels = inputs.shape.as_list()[1]
        if num_channels % group_size != 0:
            raise ValueError('Number of channels ({}) must be divisible by channels_per_group ({}).'
                             .format(num_channels, group_size))
        num_groups = num_channels // group_size

        # Flatten every group so a single reduction yields its statistics.
        grouped = tf.reshape(inputs, [-1, num_groups, group_size * side * side * side])
        means, variances = tf.nn.moments(grouped, [2], keep_dims=True)
        normalized = (grouped - means) / tf.sqrt(variances + self.eps)
        normalized = tf.reshape(normalized, [-1, num_channels, side, side, side])

        normalized = tf.transpose(normalized, [0, 2, 3, 4, 1])  # back to channels-last
        return normalized * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """The layer preserves the input shape."""
        return input_shape
if __name__ == '__main__':
    # Smoke test: build the LCN graph on a small synthetic volume.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    # Alternative input with alternating constant channels (1, 2, 1, 2):
    # volume = tf.concat([tf.ones([2, 7, 7, 7, 1]), 2*tf.ones([2, 7, 7, 7, 1]), tf.ones([2, 7, 7, 7, 1]), 2*tf.ones([2, 7, 7, 7, 1])], axis=-1)

    # Triple cumulative sum of ones over the three spatial axes.
    volume = tf.ones([2, 7, 7, 7, 4])
    for spatial_axis in (1, 2, 3):
        volume = tf.cumsum(volume, axis=spatial_axis)

    side = volume.shape.as_list()[1]
    layer = LocalContextNorm(channels_per_group=2, window_size=[3, 3, 3], img_size=side)
    normalized = layer(volume)

    '''gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
tf.global_variables_initializer().run(session=sess)'''
    # With the session above enabled, the result can be inspected with:
    # print(sess.run(normalized))
    # print(sess.run(normalized[:, :, :, :, 0]))
    # print(sess.run(normalized[:, :, :, :, 1]))
    # print(normalized.shape.as_list())
Some Notes
1、keras 报错 ‘_TfDeviceCaptureOp’ object has no attribute ‘type’,tf1.4对应keras2.0.8,重装keras=2.0.8后解决。
2、tf.cond无法完成建图,故上述代码写了两个class。
3、tf.nn.layer以及tflearn.layers中在tf1.4版本下的conv3d没有dilation_rate参数,故使用keras.layer.Conv3D。
4、参考代码中二维积分图的窗口求和使用了三维卷积,因此三维积分图需要四维卷积;但没有三维以上的卷积API,故采用Conv3D + tf.concat的策略(逐通道做Conv3D后再拼接)。
5、三维积分图运算矩阵为 kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]],可自行推导,二维积分图运算矩阵参考上述参考代码。
6、tf.pad无法完成比input size还大的padding,故use_window的window_size最大为img_size一半; ‘REFLECT’ padding方式不会重复最外面的那层,例如[-1, -1, -1]其实是复制的[1, 1, 1]的值,而不是[0, 0, 0]的值。
########################################################################
新的实现:
def Gaussian_mol(k, sigma):
    """Build a normalized, integer-quantized 3D Gaussian kernel.

    The kernel is a cube of side ``2 * k`` sampled at offsets
    ``-k, ..., k - 1`` along every axis.  Values are expressed relative to the
    corner element ``[0, 0, 0]`` (the smallest one), rounded to the nearest
    integer, and finally rescaled so that the kernel sums to 1.

    Args:
        k: half side length of the kernel; a positive integer.
        sigma: standard deviation of the Gaussian; must be positive.

    Returns:
        ``numpy.ndarray`` of shape ``(2k, 2k, 2k)`` whose entries sum to 1.

    Raises:
        ValueError: if ``k < 1`` or ``sigma <= 0``.
    """
    if k < 1:
        raise ValueError('k must be a positive integer, got {!r}'.format(k))
    if sigma <= 0:
        raise ValueError('sigma must be positive, got {!r}'.format(sigma))

    offsets_sq = (np.arange(2 * k) - k) ** 2
    dist_sq = (offsets_sq[:, None, None]
               + offsets_sq[None, :, None]
               + offsets_sq[None, None, :])

    # Ratio of the Gaussian density to its value at the corner [0, 0, 0].
    # The density's normalization constant cancels in this ratio, and the
    # divisor is the scalar corner value (not the vector A[0][0], which would
    # flatten the kernel along the last axis).
    relative = np.exp((dist_sq[0, 0, 0] - dist_sq) / (2.0 * sigma ** 2))

    levels = np.rint(relative)
    return levels / levels.sum()
LCN v1
class LocalContextNorm_v1(Layer):  # 3D
    """Gaussian-weighted local context normalization (dense sliding window).

    Window statistics are computed with plain stride-1 ``Conv3D`` layers whose
    kernels are the weights returned by ``Gaussian_mol`` (scaled by 255) and
    their squares; the resulting maps are reflect-padded back to the input
    size and used to normalize the input, followed by a learnable per-channel
    scale (``weight``) and shift (``bias``).

    The layer takes channels-last input ``[B, D, H, W, C]``.  The statistics
    convolutions produce a single output channel, so the layer appears to be
    intended for single-channel volumes -- TODO confirm.  Only
    ``window_size[0]`` is used to size the Gaussian kernel and the padding,
    so the window is assumed to be cubic.

    Args:
        channels_per_group: must be 1.
        window_size: window extent ``[D, H, W]``; the side must be even.
        eps: constant added to the variance before the square root.
        img_size: side length of the (cubic) input volume.
        Gaussian_sigma: standard deviation passed to ``Gaussian_mol``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(32, 32, 32), eps=1e-5, img_size=128,
                 Gaussian_sigma=10, **kwargs):
        super(LocalContextNorm_v1, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size
        self.sigma = Gaussian_sigma

    def build(self, input_shape):
        """Validate the configuration and create the weights and kernels.

        Raises:
            ValueError: if the input is not 5-D, if ``channels_per_group`` is
                not 1, or if the window side is odd.
        """
        if len(input_shape) != 5:  # [B, D, H, W, C]
            raise ValueError('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')
        if self.channels_per_group != 1:
            raise ValueError('Channels per group must be 1!')
        # Gaussian_mol returns a cube of side 2 * (w // 2), which only matches
        # the convolution kernel size w when w is even.
        if self.window_size[0] % 2 != 0:
            raise ValueError('Window size must be even, got {}.'.format(self.window_size[0]))
        self.num_features = input_shape[-1]
        # add_weight registers the variables with Keras so they are trained and
        # saved together with the model (a bare tf.Variable is not tracked).
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)
        # The factor 255 matches statistics tuned on 0-255 images while the
        # network sees 0-1 inputs (see the accompanying notes).
        self.Gaussian_weight = Gaussian_mol(self.window_size[0] // 2, self.sigma) * 255
        self.sq_weight = self.Gaussian_weight * self.Gaussian_weight
        self.built = True

    def _window_conv(self, kernel_values):
        """Create the stride-1 convolution that applies a fixed window kernel."""
        return keras.layers.Conv3D(
            filters=1, kernel_size=self.window_size, padding='valid',
            kernel_initializer=keras.initializers.Constant(kernel_values),
            strides=[1, 1, 1], data_format='channels_first')

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` (``[B, D, H, W, C]``) and apply scale/shift."""
        window = self.window_size[0]

        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])  # -> [B, C, D, H, W]
        sq_inputs = inputs * inputs

        # Weighted window sums; gradients are not propagated through them.
        sums = tf.stop_gradient(self._window_conv(self.Gaussian_weight)(inputs))
        sq_sums = tf.stop_gradient(self._window_conv(self.sq_weight)(sq_inputs))

        n = self.window_size[0] * self.window_size[1] * self.window_size[2]
        means = sums / n
        # Clamp round-off negatives so the square root below never yields NaN.
        variances = tf.maximum(1.0 / n * (sq_sums - sums * sums / n), 0.0)

        # A 'valid' convolution shortens every spatial axis by window - 1;
        # reflect-pad the statistics back to the input size.
        pad_before = (window - 1) // 2
        pad_after = (window - 1) - pad_before
        paddings = [[0, 0], [0, 0]] + [[pad_before, pad_after]] * 3
        pad_mean = tf.pad(means, paddings, 'REFLECT')
        pad_var = tf.pad(variances, paddings, 'REFLECT') + self.eps

        new_inputs = (inputs - pad_mean) / tf.sqrt(pad_var)
        new_inputs = tf.transpose(new_inputs, [0, 2, 3, 4, 1])  # back to channels-last
        return new_inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """The layer preserves the input shape."""
        return input_shape
LCN v2
class LocalContextNorm_v2(Layer):  # 3D
    """Gaussian-weighted local context normalization (non-overlapping blocks).

    The volume is split into non-overlapping ``window_size`` blocks.  Block
    statistics are computed with strided ``Conv3D`` layers whose kernels are
    the weights returned by ``Gaussian_mol`` (scaled by 255) and their
    squares, then upsampled back to the input size.  The input is multiplied
    by the same weights tiled over the volume, normalized with the block
    statistics, rescaled to ``[0, 1]`` using the global min and max of the
    result, and finally passed through a learnable per-channel scale
    (``weight``) and shift (``bias``).

    The layer takes channels-last input ``[B, D, H, W, C]``.  The statistics
    convolutions produce a single output channel, so the layer appears to be
    intended for single-channel volumes -- TODO confirm.  Only
    ``window_size[0]`` is used to size the Gaussian kernel, so the window is
    assumed to be cubic.

    Args:
        channels_per_group: must be 1.
        window_size: block extent ``[D, H, W]``; the side must be even and
            divide ``img_size``.
        eps: constant added to the variance before the square root.
        img_size: side length of the (cubic) input volume.
        Gaussian_sigma: standard deviation passed to ``Gaussian_mol``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_per_group=1, window_size=(16, 16, 16), eps=1e-5, img_size=128,
                 Gaussian_sigma=30, **kwargs):
        super(LocalContextNorm_v2, self).__init__(**kwargs)
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  # [D, H, W]
        self.img_size = img_size
        self.sigma = Gaussian_sigma

    def build(self, input_shape):
        """Validate the configuration and create the weights and kernels.

        Raises:
            ValueError: if the input is not 5-D, if ``channels_per_group`` is
                not 1, if the window side is odd, or if it does not divide
                ``img_size``.
        """
        if len(input_shape) != 5:  # [B, D, H, W, C]
            raise ValueError('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')
        if self.channels_per_group != 1:
            raise ValueError('Channels per group must be 1!')
        window = self.window_size[0]
        # Gaussian_mol returns a cube of side 2 * (w // 2), which only matches
        # the convolution kernel size w when w is even.
        if window % 2 != 0:
            raise ValueError('Window size must be even, got {}.'.format(window))
        # The tiled weights and the upsampled statistics only line up with
        # the input when the blocks cover the volume exactly.
        if self.img_size % window != 0:
            raise ValueError('Image size {} must be a multiple of window size {}.'
                             .format(self.img_size, window))
        self.num_features = input_shape[-1]
        # add_weight registers the variables with Keras so they are trained and
        # saved together with the model (a bare tf.Variable is not tracked).
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)
        # The factor 255 matches statistics tuned on 0-255 images while the
        # network sees 0-1 inputs (see the accompanying notes).
        self.Gaussian_weight = Gaussian_mol(window // 2, self.sigma) * 255
        self.sq_weight = self.Gaussian_weight * self.Gaussian_weight
        # Repeat the block weights along each spatial axis so that they cover
        # the whole volume; shape [1, 1, D, H, W].
        repeats = [max(self.img_size // self.window_size[axis], 1) for axis in range(3)]
        self.stack_weight = np.tile(self.Gaussian_weight[np.newaxis, np.newaxis], [1, 1] + repeats)
        self.built = True

    def _block_conv(self, kernel_values):
        """Create the strided convolution that applies a fixed block kernel."""
        return keras.layers.Conv3D(
            filters=1, kernel_size=self.window_size, padding='valid',
            kernel_initializer=keras.initializers.Constant(kernel_values),
            strides=[self.window_size[0], self.window_size[1], self.window_size[2]],
            data_format='channels_first')

    def _upsample(self):
        """Create the layer that repeats each block statistic over its block."""
        return keras.layers.UpSampling3D(
            size=(self.window_size[0], self.window_size[1], self.window_size[2]),
            data_format='channels_first')

    def call(self, inputs, **kwargs):
        """Normalize ``inputs`` (``[B, D, H, W, C]``) and apply scale/shift."""
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])  # -> [B, C, D, H, W]
        sq_inputs = inputs * inputs

        # Conv3D + UpSampling3D: one statistic per block, repeated over it.
        sums = tf.stop_gradient(self._block_conv(self.Gaussian_weight)(inputs))
        sq_sums = tf.stop_gradient(self._block_conv(self.sq_weight)(sq_inputs))

        n = self.window_size[0] * self.window_size[1] * self.window_size[2]
        means = sums / n
        # Clamp round-off negatives so the square root below never yields NaN.
        variances = tf.maximum(1.0 / n * (sq_sums - sums * sums / n), 0.0) + self.eps

        pad_mean = tf.stop_gradient(self._upsample()(means))
        pad_var = tf.stop_gradient(self._upsample()(variances))

        weight_inputs = inputs * tf.constant(self.stack_weight, dtype=tf.float32)
        new_inputs = (weight_inputs - pad_mean) / tf.sqrt(pad_var)
        new_inputs = tf.transpose(new_inputs, [0, 2, 3, 4, 1])  # back to channels-last

        # Rescale to [0, 1] using the global min and max of the whole tensor.
        new_inputs = new_inputs - tf.reduce_min(new_inputs)
        new_inputs = new_inputs / tf.reduce_max(new_inputs)
        return new_inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        """The layer preserves the input shape."""
        return input_shape
Some Notes
1、周报1220中是在原始image(取值范围0-255)上做的初步实现,并据此调整了窗口大小和sigma;但网络中的输入取值范围是0-1,而计算中涉及平方操作,两者不再是线性关系,故在上述实现中将gaussian weight乘以255,以匹配周报1220中的结果,经检查实现正确。
2、由于要加入gaussian weight,没法再用积分图,故采用平凡的3D卷积API进行实现,详情见上述代码。