"""
坐标注意力机制代码
"""
import numpy as np
from tensorflow.keras.layers import AvgPool2D, Conv2D, BatchNormalization, Activation
from tensorflow.keras import Model
from tensorflow.keras.activations import sigmoid
from tensorflow import split,concat,transpose
import tensorflow as tf
def ac_swish(x):
    # h-swish: x * ReLU6(x + 3) / 6, a cheap piecewise approximation of swish
    return x * (tf.nn.relu6(x + 3) / 6)
class Coordinate_Attention(Model):
    def __init__(self, w, h, inp, oup, groups=32):
        """
        :param w: feature-map width
        :param h: feature-map height
        :param inp: number of input channels
        :param oup: number of output channels
        :param groups: channel-reduction factor for the bottleneck
        """
        super(Coordinate_Attention, self).__init__()
        self.w = w
        self.h = h
        self.inp = inp
        self.oup = oup
        self.groups = groups
        # Directional average pooling. 'valid' padding collapses one spatial
        # axis to length 1, producing (N, H, 1, C) and (N, 1, W, C) descriptors;
        # 'same' padding would leave the shape unchanged and break the mechanism.
        self.pool_h = AvgPool2D(pool_size=(1, self.w), strides=1, padding='valid')
        self.pool_w = AvgPool2D(pool_size=(self.h, 1), strides=1, padding='valid')
        self.mip = max(8, self.inp // self.groups)
        self.conv1 = Conv2D(filters=self.mip, kernel_size=(1, 1), strides=1, padding='same')
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters=self.oup, kernel_size=(1, 1), strides=1, padding='same')
        self.conv3 = Conv2D(filters=self.oup, kernel_size=(1, 1), strides=1, padding='same')
        self.ac = Activation(ac_swish)
    def call(self, inputs):
        residual = inputs
        x = residual
        # inputs: (N, H, W, C)
        x_h = self.pool_h(x)                    # (N, H, 1, C)
        x_w = self.pool_w(x)                    # (N, 1, W, C)
        x_w = transpose(x_w, [0, 2, 1, 3])      # (N, W, 1, C)
        y = concat([x_h, x_w], axis=1)          # (N, H + W, 1, C)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.ac(y)
        # Split back into the two directional branches; explicit sizes keep
        # this correct when H != W.
        x_h, x_w = split(y, [self.h, self.w], axis=1)
        x_w = transpose(x_w, [0, 2, 1, 3])      # (N, 1, W, mip)
        x_h = sigmoid(self.conv2(x_h))          # (N, H, 1, oup)
        x_w = sigmoid(self.conv3(x_w))          # (N, 1, W, oup)
        y = residual * x_w * x_h                # broadcast back to (N, H, W, C)
        return y
if __name__ == '__main__':
    model = Coordinate_Attention(w=14, h=14, inp=512, oup=512)
    model.build((1, 14, 14, 512))
    model.summary()
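
For context, here is a minimal usage sketch that is not from the original post: it wires Coordinate_Attention after an ordinary convolution inside a small functional model (the helper name build_demo_net is hypothetical) and confirms that the block preserves the feature-map shape.

# Hypothetical usage sketch (assumes the Coordinate_Attention class above).
import tensorflow as tf

def build_demo_net():
    inputs = tf.keras.Input(shape=(14, 14, 3))
    x = tf.keras.layers.Conv2D(512, 3, padding='same', activation='relu')(inputs)
    x = Coordinate_Attention(w=14, h=14, inp=512, oup=512)(x)  # shape unchanged
    outputs = tf.keras.layers.GlobalAveragePooling2D()(x)
    return tf.keras.Model(inputs, outputs)

demo = build_demo_net()
print(demo(tf.random.normal((2, 14, 14, 3))).shape)  # (2, 512)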
This article presents a model called Coordinate Attention, which applies a coordinate-aware attention mechanism inside convolutional neural networks: directional pooling operations and sigmoid activations let it capture spatial dependencies in the input features. The core code shows how to construct and invoke the mechanism for image feature maps of a given width and height.
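
To make that broadcasting concrete, the following sketch (an illustration assuming the Coordinate_Attention class defined above, not part of the original post) traces a non-square input through the block: the two attention maps have shapes (N, H, 1, C) and (N, 1, W, C), so their product re-weights every (h, w) position while leaving the overall feature-map shape intact.

# Shape trace for a non-square feature map (hypothetical example).
import tensorflow as tf

ca = Coordinate_Attention(w=20, h=14, inp=64, oup=64)
x = tf.random.normal((1, 14, 20, 64))  # (N, H, W, C)
y = ca(x)
print(y.shape)  # (1, 14, 20, 64): input shape is preserved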