Criss-Cross Attention 模块的 TensorFlow 代码实现
本文的代码是边学习边写的，如有问题或需要指正之处，欢迎联系！
模块结构
Affinity 操作
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
class criss_cross_attention_Affinity(tf.keras.layers.Layer):
    """Criss-cross affinity: dot products between each pixel and its row/column.

    For every pixel (i, j) of a channels-last feature map ``x`` of shape
    ``(batch, H, W, C)`` this layer computes the channel-wise dot product with
    every other pixel of row ``i`` and with every pixel of column ``j`` (the
    pixel itself is counted once, in the column part), i.e. ``H + W - 1``
    scores per pixel.

    Output shape is ``(batch, H, W, W - 1 + H)``. Along the last axis the
    scores are ordered as: row ``i`` from left to right skipping column ``j``,
    followed by column ``j`` from top to bottom (including row ``i``).

    NOTE(review): ``x`` is used directly as both query and key; no learned
    projections are applied in this layer.

    Args:
        axis: Not used by the computation; kept so existing constructor
            calls and saved configs keep working.
    """

    def __init__(self, axis=1, **kwargs):
        super(criss_cross_attention_Affinity, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        width = x.shape[2]
        if width is None:
            # Spatial width only known at run time.
            width = tf.shape(x)[2]

        # row[b, i, j, w] = <x[b, i, j, :], x[b, i, w, :]> for every w in row i.
        row = tf.einsum('bijc,biwc->bijw', x, x)
        # col[b, i, j, h] = <x[b, i, j, :], x[b, h, j, :]> for every h in column j.
        col = tf.einsum('bijc,bhjc->bijh', x, x)

        # Drop the self term (w == j) from the row scores while keeping the
        # remaining columns in ascending order: output slot k holds column k
        # when k < j and column k + 1 otherwise.
        j = tf.range(width)[:, tf.newaxis]
        k = tf.range(width - 1)[tf.newaxis, :]
        row = tf.where(k < j, row[..., :width - 1], row[..., 1:])

        return tf.concat([row, col], axis=-1)

    def get_config(self):
        config = {'axis': self.axis}
        base_config = super(criss_cross_attention_Affinity, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Aggregation 操作
class criss_cross_attention_Aggregation(tf.keras.layers.Layer):
    """Aggregate each pixel's row and column features using affinity weights.

    ``Affinity`` must use the layout produced by
    ``criss_cross_attention_Affinity`` for the same ``x``. Its shape is
    ``(batch, H, W, W - 1 + H)``. The first ``W - 1`` entries score row ``i``
    from left to right, skipping column ``j``. The last ``H`` entries score
    column ``j`` from top to bottom.

    A softmax over the last axis turns the scores into weights. The output at
    (i, j) is the weighted sum of those row and column pixels of ``x``, so the
    output shape equals ``x``'s shape ``(batch, H, W, C)``.

    Args:
        axis: Not used by the computation; kept so existing constructor
            calls and saved configs keep working.
    """

    def __init__(self, axis=1, **kwargs):
        super(criss_cross_attention_Aggregation, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x, Affinity):
        width = x.shape[2]
        if width is None:
            # Spatial width only known at run time.
            width = tf.shape(x)[2]

        # Normalize all H + W - 1 scores of each pixel jointly.
        weights = tf.nn.softmax(Affinity, axis=-1)
        row_w = weights[..., :width - 1]  # (batch, H, W, W - 1)
        col_w = weights[..., width - 1:]  # (batch, H, W, H)

        # Re-insert a zero weight at the self position (w == j) so the row
        # weights line up with the W columns of x. Column w takes slot w when
        # w < j and slot w - 1 when w > j.
        zero = tf.zeros_like(weights[..., :1])
        slot_w = tf.concat([row_w, zero], axis=-1)        # slot w at column w
        slot_w_minus_1 = tf.concat([zero, row_w], axis=-1)  # slot w - 1 at column w
        j = tf.range(width)[:, tf.newaxis]
        w = tf.range(width)[tf.newaxis, :]
        row_full = tf.where(
            w < j,
            slot_w,
            tf.where(w > j, slot_w_minus_1, tf.zeros_like(slot_w)),
        )

        out_row = tf.einsum('bijw,biwc->bijc', row_full, x)
        out_col = tf.einsum('bijh,bhjc->bijc', col_w, x)
        return out_row + out_col

    def get_config(self):
        config = {'axis': self.axis}
        base_config = super(criss_cross_attention_Aggregation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
合并两个操作（组成完整的 Criss-Cross Attention 模块）
def criss_cross_attention(x, filters=64):
    """Apply one criss-cross attention pass wrapped in a downsample/upsample.

    The input goes through five steps:

    1. A stride-2 3x3 convolution reduces it to ``filters`` channels.
    2. Criss-cross affinities are computed on that reduced map.
    3. The features are aggregated with those affinities.
    4. The convolution output is added back as a residual.
    5. The result is upsampled by 2 with bilinear interpolation.

    Args:
        x: Keras tensor, assumed channels-last ``(batch, H, W, C)`` (the
            default Conv2D data format) -- confirm against the caller.
        filters: Number of channels produced by the reducing convolution,
            and therefore of the returned tensor. Defaults to 64, the value
            that was previously hard-coded.

    Returns:
        Tensor of shape ``(batch, 2 * ceil(H / 2), 2 * ceil(W / 2), filters)``.
        The spatial size matches the input only when H and W are even.
    """
    # Stride 2 quarters the number of pixels the attention has to cover.
    x = layers.Conv2D(filters=filters, kernel_size=3, padding='same', strides=2)(x)
    x_origin = x
    affinity = criss_cross_attention_Affinity(1)(x)
    out = criss_cross_attention_Aggregation(1)(x, affinity)
    # Residual connection around the attention.
    out = layers.Add()([out, x_origin])
    out = layers.UpSampling2D(size=2, interpolation='bilinear')(out)
    return out
模型打印结果
问题
由于该模块需要对每个像素，与其在原特征图上所在行和列（十字形区域）上的像素逐一计算，代码中采用了循环遍历每个像素的写法。这导致计算图庞大、计算速度慢，暂时还没能解决这个问题。一个可能的思路是：用 `tf.einsum` 一次性计算整行、整列像素之间的内积，从而去掉逐像素的 Python 循环。