"""
坐标注意力机制代码
"""
import numpy as np
from tensorflow.keras.layers import AvgPool2D, Conv2D, BatchNormalization, Activation
from tensorflow.keras import Model
from tensorflow.keras.activations import sigmoid
from tensorflow import split,concat,transpose
import tensorflow as tf
def ac_swish(x):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``.

    Piecewise linear: the gate ``relu6(x + 3) / 6`` is 0 for ``x <= -3`` and
    1 for ``x >= 3``, so the output is 0 and ``x`` respectively in those ranges.
    """
    return x * (tf.nn.relu6(x + 3) / 6)
class Coordinate_Attention(Model):
    """Coordinate attention block for channels-last (NHWC) feature maps.

    The input is average-pooled along each spatial axis separately, giving one
    descriptor per row (pooled over width) and one per column (pooled over
    height).  The two descriptors are concatenated and pass through a shared
    1x1 conv -> batch norm -> hard-swish (``ac_swish``) bottleneck.  They are
    then split apart again, and each is turned into an attention map by its
    own 1x1 conv + sigmoid.  The input is rescaled by both maps.
    """

    def __init__(self, w, h, inp, oup, groups=32):
        """
        :param w: expected feature-map width.  Stored for backward
            compatibility; pooling adapts to the actual input width.
        :param h: expected feature-map height (same note as ``w``).
        :param inp: number of input channels; sets the bottleneck width.
        :param oup: number of channels of the two attention maps.  The maps
            are multiplied with the input, so this should equal ``inp`` (or be
            1) for the product to broadcast.
        :param groups: reduction divisor; the bottleneck has
            ``max(8, inp // groups)`` channels.
        """
        super(Coordinate_Attention, self).__init__()
        self.w = w
        self.h = h
        self.inp = inp
        self.oup = oup
        self.groups = groups
        # Bottleneck width, never below 8 channels.
        self.mip = max(8, self.inp // self.groups)
        self.conv1 = Conv2D(filters=self.mip, kernel_size=(1, 1), strides=1, padding='same')
        self.bn1 = BatchNormalization()
        # conv2 produces the height-wise map, conv3 the width-wise map.
        self.conv2 = Conv2D(filters=self.oup, kernel_size=(1, 1), strides=1, padding='same')
        self.conv3 = Conv2D(filters=self.oup, kernel_size=(1, 1), strides=1, padding='same')
        self.ac = Activation(ac_swish)

    def call(self, inputs, training=None):
        """Apply coordinate attention to ``inputs``.

        :param inputs: 4-D channels-last tensor; axis 1 is treated as height
            and axis 2 as width (presumably (batch, H, W, inp) -- confirm
            against callers).
        :param training: forwarded to the batch-norm layer; ``None`` lets
            Keras take it from the outer call context.
        :return: ``inputs`` rescaled by the height-wise and width-wise
            attention maps (same shape as ``inputs`` when ``oup == inp``).
        """
        identity = inputs
        height = tf.shape(inputs)[1]

        # Directional global average pooling; keepdims keeps tensors 4-D:
        #   x_h: mean over width  -> (N, H, 1, C)
        #   x_w: mean over height -> (N, 1, W, C)
        # (A stride-1 'same'-padded AvgPool2D would keep the full H x W map
        # instead of producing one value per row/column.)
        x_h = tf.reduce_mean(inputs, axis=2, keepdims=True)
        x_w = tf.reduce_mean(inputs, axis=1, keepdims=True)
        # Lay the column descriptor out along axis 1 too -> (N, W, 1, C), so
        # both can share a single bottleneck pass.
        x_w = transpose(x_w, [0, 2, 1, 3])

        y = concat([x_h, x_w], axis=1)  # (N, H + W, 1, C)
        y = self.conv1(y)
        y = self.bn1(y, training=training)
        y = self.ac(y)

        # Split back by the actual height instead of two equal halves, so
        # non-square feature maps (H != W) work as well.
        x_h = y[:, :height]                           # (N, H, 1, mip)
        x_w = transpose(y[:, height:], [0, 2, 1, 3])  # (N, 1, W, mip)

        a_h = sigmoid(self.conv2(x_h))  # (N, H, 1, oup)
        a_w = sigmoid(self.conv3(x_w))  # (N, 1, W, oup)
        # Broadcasting spreads each map across the other spatial axis.
        return identity * a_w * a_h
if __name__ == '__main__':
    # Smoke check: build the block for a single 14x14 feature map with 512
    # channels and print its layer summary.
    input_shape = (1, 14, 14, 512)
    _, height, width, channels = input_shape
    ca_block = Coordinate_Attention(w=width, h=height, inp=channels, oup=channels)
    ca_block.build(input_shape)
    ca_block.summary()
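# Coordinate Attention - TensorFlow 2 (original title: 坐标注意力机制-Tensorflow2)
# Blog-page residue: 最新推荐文章于 2024-06-25 17:57:01 发布 ("latest recommended article published 2024-06-25 17:57:01")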