文章目录
前言
研究生阶段的一些工作：因为涉及注意力机制方面的研究，所以复现了一些比较有名的注意力模块。这些都是我和朋友根据自己的理解复现的，主要用的是 Keras（部分模块用的是 TensorFlow 1.x / TF-Slim），不保证复现的正确性，欢迎交流。
1. ECA
https://blog.csdn.net/qq_35054151/article/details/115434812
import math
from keras.layers import *
from keras.layers import Activation
from keras.layers import GlobalAveragePooling2D
import keras.backend as K
import tensorflow as tf
def eca_layer(inputs_tensor=None, num=None, gamma=2, b=1):
    """Efficient Channel Attention (ECA-Net) block.

    Squeezes the input with global average pooling, runs a 1-D convolution
    across the channel axis (kernel size derived from the channel count) and
    rescales the input channels with the resulting sigmoid weights.

    :param inputs_tensor: 4-D input, shape [batch, h, w, channels].
    :param num: suffix appended to the Conv1D / Activation layer names.
    :param gamma: divisor in the adaptive kernel-size formula.
    :param b: offset in the adaptive kernel-size formula.
    :return: inputs_tensor multiplied channel-wise by the attention weights.
    """
    n_channels = K.int_shape(inputs_tensor)[-1]
    # Adaptive kernel size k = |(log2(C) + b) / gamma|, bumped to the next odd value.
    kernel = int(abs((math.log(n_channels, 2) + b) / gamma))
    if kernel % 2 == 0:
        kernel += 1
    suffix = str(num)

    squeezed = GlobalAveragePooling2D()(inputs_tensor)   # [batch, channels]
    weights = Reshape((n_channels, 1))(squeezed)          # [batch, channels, 1]
    weights = Conv1D(1, kernel_size=kernel, padding="same",
                     name="eca_conv1_" + suffix)(weights)
    # The layer name says "relu" but the activation is a sigmoid; renaming it
    # would change the layer names of existing models.
    weights = Activation('sigmoid', name='eca_conv1_relu_' + suffix)(weights)
    weights = Reshape((1, 1, n_channels))(weights)        # broadcast over h and w
    return multiply([inputs_tensor, weights])
2. Coordinate attention
import tensorflow as tf
from keras.layers import Lambda,Concatenate,Reshape,Conv2D,BatchNormalization,Activation,Multiply,Add
def coordinate(inputs, ratio=2, name="name"):
    """Coordinate attention (Hou et al., CVPR 2021) followed by a residual add.

    The input is average-pooled along each spatial axis separately. The two
    1-D descriptors are encoded jointly by a shared 1x1 Conv2D + BN + ReLU,
    split again and turned into two sigmoid gates, one per spatial axis.
    Unlike the paper, the gated features are added back to the input:
    the result is ``inputs + inputs * col_gate * row_gate``.

    :param inputs: 4-D channels-last tensor with fully static dims [N, rows, cols, C].
    :param ratio: channel reduction ratio; hidden width is max(C // ratio, ratio).
    :param name: prefix of the three named Conv2D layers (name+'1', name+'2', name+'3').
    :return: tensor with the same shape as inputs.
    """
    rows, cols, C = [int(d) for d in inputs.shape[1:]]
    hidden = max(int(C // ratio), ratio)

    # Mean over axis 1 -> one descriptor per column: [N, cols, C]
    col_desc = Lambda(lambda t: tf.reduce_mean(t, axis=1))(inputs)
    # Mean over axis 2 -> one descriptor per row: [N, rows, C]
    row_desc = Lambda(lambda t: tf.reduce_mean(t, axis=2))(inputs)

    joint = Concatenate(axis=1)([col_desc, row_desc])   # [N, cols + rows, C]
    joint = Reshape((1, rows + cols, C))(joint)         # [N, 1, cols + rows, C]
    joint = Conv2D(hidden, 1, name=name + '1')(joint)
    joint = BatchNormalization()(joint)
    joint = Activation('relu')(joint)

    col_part, row_part = Lambda(lambda t: tf.split(t, [cols, rows], axis=2))(joint)
    # Column gates keep shape [N, 1, cols, hidden]; row gates become [N, rows, 1, hidden]
    # so that both broadcast against the [N, rows, cols, C] input.
    row_part = Reshape((rows, 1, hidden))(row_part)
    col_gate = Conv2D(C, 1, activation='sigmoid', name=name + "2")(col_part)
    row_gate = Conv2D(C, 1, activation='sigmoid', name=name + "3")(row_part)

    gated = Multiply()([inputs, col_gate, row_gate])
    return Add()([inputs, gated])
3. Dual attention
import keras
from keras.layers import Activation, Conv2D
import keras.backend as K
import tensorflow as tf
from keras.layers import Layer
# Position attention module (PAM) of DANet.
class PAM(Layer):
    """Position attention module from DANet (Fu et al., CVPR 2019).

    Every spatial position is replaced by a weighted sum of the ``D``
    projections of all positions. The weights are a softmax over the
    similarities between the ``B`` (query) and ``C`` (key) projections. The
    result is scaled by a learnable scalar ``beta`` (zero-initialized by
    default) and added to the input.

    Expects a 4-D channels-last input whose height, width and channel count
    are static (they are read from the static shape in ``call``). The
    query/key projections have ``channels // 8`` outputs, so channels should
    be >= 8.

    The ``kernal_*`` spelling of the keyword arguments is kept for backward
    compatibility.
    """

    def __init__(self,
                 beta_initializer=keras.initializers.Zeros(),
                 beta_regularizer=None,
                 beta_constraint=None,
                 kernal_initializer='he_normal',
                 kernal_regularizer=None,
                 kernal_constraint=None,
                 **kwargs):
        super(PAM, self).__init__(**kwargs)
        # Resolve identifiers (e.g. 'he_normal') to objects up front so that
        # get_config can serialize them; add_weight accepts either form.
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.kernal_initializer = keras.initializers.get(kernal_initializer)
        self.kernal_regularizer = keras.regularizers.get(kernal_regularizer)
        self.kernal_constraint = keras.constraints.get(kernal_constraint)

    def _add_kernel(self, name, shape):
        """Create one projection kernel using the shared kernal_* settings."""
        return self.add_weight(shape=shape,
                               initializer=self.kernal_initializer,
                               name=name,
                               regularizer=self.kernal_regularizer,
                               constraint=self.kernal_constraint,
                               trainable=True)

    def build(self, input_shape):
        filters = int(input_shape[-1])
        # Weight creation order (beta, b, c, d) is kept so get_weights/set_weights
        # stay compatible with previously saved weights.
        self.beta = self.add_weight(shape=(1,),
                                    initializer=self.beta_initializer,
                                    name='beta',
                                    regularizer=self.beta_regularizer,
                                    constraint=self.beta_constraint,
                                    trainable=True)
        self.kernel_b = self._add_kernel('kernel_b', (filters, filters // 8))
        self.kernel_c = self._add_kernel('kernel_c', (filters, filters // 8))
        self.kernel_d = self._add_kernel('kernel_d', (filters, filters))
        super(PAM, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        _, h, w, filters = inputs.get_shape().as_list()
        positions = h * w
        b = K.reshape(K.dot(inputs, self.kernel_b), (-1, positions, filters // 8))  # queries
        c = K.reshape(K.dot(inputs, self.kernel_c), (-1, positions, filters // 8))  # keys
        d = K.reshape(K.dot(inputs, self.kernel_d), (-1, positions, filters))       # values
        # attention[i, j]: weight of position j for output position i (rows sum to 1).
        attention = K.softmax(K.batch_dot(b, K.permute_dimensions(c, (0, 2, 1))))
        out = K.reshape(K.batch_dot(attention, d), (-1, h, w, filters))
        return self.beta * out + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = {
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
            'kernal_initializer': keras.initializers.serialize(self.kernal_initializer),
            'kernal_regularizer': keras.regularizers.serialize(self.kernal_regularizer),
            'kernal_constraint': keras.constraints.serialize(self.kernal_constraint),
        }
        base_config = super(PAM, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Channel attention module (CAM) of DANet.
class CAM(Layer):
    """Channel attention module from DANet (Fu et al., CVPR 2019).

    Flattens the input to ``A`` of shape [batch, h*w, C] and computes the
    channel-similarity matrix ``A^T A`` ([batch, C, C]). It softmax-normalizes
    each row and rebuilds every channel as an attention-weighted sum of all
    input channels. The result is scaled by a learnable scalar ``gamma``
    (zero-initialized by default) and added to the input.

    Expects a 4-D channels-last input whose height, width and channel count
    are static (they are read from the static shape in ``call``).
    """

    def __init__(self,
                 gamma_initializer=keras.initializers.Zeros(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(CAM, self).__init__(**kwargs)
        # Resolve identifiers to objects so get_config can serialize them;
        # add_weight accepts either form.
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
        super(CAM, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs):
        _, h, w, filters = inputs.get_shape().as_list()
        vec_a = K.reshape(inputs, (-1, h * w, filters))      # [batch, hw, C]
        vec_aT = K.permute_dimensions(vec_a, (0, 2, 1))      # [batch, C, hw]
        # energy[i, j] = A_i . A_j; softmax over j, so attention[j, :] sums to 1.
        attention = K.softmax(K.batch_dot(vec_aT, vec_a))    # [batch, C, C]
        # Output channel j = sum_i attention[j, i] * A_i, i.e. A @ attention^T.
        # This normalizes over the source channel i, as in the DANet CAM formula.
        # Multiplying by `attention` itself would weight channel j by a column,
        # and a column does not sum to 1.
        out = K.batch_dot(vec_a, K.permute_dimensions(attention, (0, 2, 1)))
        out = K.reshape(out, (-1, h, w, filters))
        return self.gamma * out + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = {
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint),
        }
        base_config = super(CAM, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# 使用方法
# pam = PAM()(reduce_conv5_3)
# cam = CAM()(reduce_conv5_3)
# feature_sum = add([pam, cam])
4. FrequencyChannelAttention
import math
import tensorflow as tf
import math
def get_freq_indices(method):
    """Return the FcaNet DCT frequency indices for a selection method.

    ``method`` is a prefix ('top', 'bot' or 'low') followed by the number of
    frequencies (1, 2, 4, 8, 16 or 32). The result is a pair of lists
    ``(x_indices, y_indices)`` of that length. Each list is a prefix of the
    corresponding precomputed table, whose entries lie in 0..6.

    Raises AssertionError for an unsupported method string.
    """
    assert method in [prefix + str(count)
                      for prefix in ('top', 'bot', 'low')
                      for count in (1, 2, 4, 8, 16, 32)]
    count = int(method[3:])
    tables = {
        'top': ([0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2,
                 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 6, 1],
                [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2,
                 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 5, 3]),
        'low': ([0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0,
                 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4],
                [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5,
                 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3]),
        'bot': ([6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5,
                 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 3, 6],
                [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1,
                 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 3, 3]),
    }
    # Same substring test and precedence as the original if/elif chain.
    for prefix in ('top', 'low', 'bot'):
        if prefix in method:
            xs, ys = tables[prefix]
            return xs[:count], ys[:count]
    raise NotImplementedError
# Attention layer (FcaNet multi-spectral channel attention).
def MultiSpectralAttentionLayer(x, channel, dct_h, dct_w, reduction=16, freq_sel_method='top2'):
    """FcaNet multi-spectral channel attention (TF1 ``tf.layers`` graph code).

    Each channel group of ``x`` is projected onto one selected 2-D DCT basis.
    The resulting ``[batch, channel]`` descriptor goes through a two-layer
    bottleneck MLP and a sigmoid, and ``x`` is rescaled channel-wise.

    Args:
        x: channels-last feature map. No pooling is applied here, so its
            spatial size must equal ``(dct_h, dct_w)`` for the element-wise
            product with the DCT filter to be valid.
        channel: number of channels of ``x``; must be divisible by the number
            of selected frequencies.
        dct_h, dct_w: spatial size of the DCT filter.
        reduction: bottleneck ratio of the MLP.
        freq_sel_method: a key accepted by ``get_freq_indices`` (e.g. 'top16').

    Returns:
        ``x`` multiplied by the per-channel attention; same shape as ``x``.
    """
    mapper_x, mapper_y = get_freq_indices(freq_sel_method)
    # The tables hold indices on a 7x7 grid; scale them to the DCT size.
    mapper_x = [u * (dct_h // 7) for u in mapper_x]
    mapper_y = [v * (dct_w // 7) for v in mapper_y]
    y = MultiSpectralDCTLayer(x, dct_h, dct_w, mapper_x, mapper_y, channel)
    y = tf.layers.dense(y, channel // reduction, activation=tf.nn.relu)
    y = tf.layers.dense(y, channel)
    y = tf.math.sigmoid(y)
    # -1 keeps the batch dimension dynamic (a static unknown batch size cannot
    # be put in the shape list). The [-1, 1, 1, channel] weights broadcast over
    # the spatial axes of x, so no explicit tiling is needed.
    y = tf.reshape(y, [-1, 1, 1, channel])
    return x * y
def MultiSpectralDCTLayer(x, height, width, mapper_x, mapper_y, channel):
    """Project each channel of ``x`` onto its assigned 2-D DCT basis.

    Builds the ``[height, width, channel]`` filter with ``get_dct_filter``,
    multiplies it element-wise with ``x`` and sums over axes 1 and 2.

    Args:
        x: channels-last feature map; its spatial size is assumed to be
            ``(height, width)`` so the element-wise product broadcasts.
        height, width: size of the DCT basis.
        mapper_x, mapper_y: equal-length lists of frequency indices along the
            height and width axes, one pair per channel group.
        channel: number of channels; must be divisible by ``len(mapper_x)``.

    Returns:
        ``x * filter`` reduced over axes 1 and 2 (``[batch, channel]`` for a
        4-D input).
    """
    # Every channel group needs a frequency on both axes.
    assert len(mapper_x) == len(mapper_y)
    assert channel % len(mapper_x) == 0
    weight = get_dct_filter(height, width, mapper_x, mapper_y, channel)
    return tf.reduce_sum(x * weight, [1, 2])
def build_filter(pos, freq, POS):
    """Value of the orthonormal 1-D DCT-II basis ``freq`` at sample ``pos``.

    Computes ``cos(pi * freq * (pos + 0.5) / POS) / sqrt(POS)`` and multiplies
    it by ``sqrt(2)`` for every non-zero frequency.

    Args:
        pos: sample index in ``[0, POS)``.
        freq: frequency index (a Python number).
        POS: length of the 1-D signal.

    Returns:
        float32 scalar tensor.
    """
    pi = tf.constant(math.pi)
    # Keep pos and POS in separate tensors: the denominator is the signal
    # length, and dividing by pos would give inf at pos == 0.
    pos_f = tf.cast(pos, tf.float32)
    freq_f = tf.cast(freq, tf.float32)
    size_f = tf.cast(POS, tf.float32)
    result = tf.math.cos(pi * freq_f * (pos_f + 0.5) / size_f) / tf.math.sqrt(size_f)
    # Test the Python value of freq. In TF1 graph mode, `tensor == 0` compares
    # object identity and is always False.
    if freq == 0:
        return result
    return result * tf.math.sqrt(tf.cast(2, tf.float32))
def get_dct_filter(tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
    """Build the fixed 2-D DCT filter bank used by FcaNet.

    Channels are split into groups of ``channel // len(mapper_x)``. Group ``i``
    holds the 2-D DCT-II basis with frequency ``mapper_x[i]`` along the first
    axis and ``mapper_y[i]`` along the second. Channels after the last full
    group stay 0.

    The values are computed in Python and returned as a ``tf.constant``. The
    filter is therefore not a trainable parameter, and no assign ops are
    needed. In TF1 graph mode, ``Variable[...].assign`` only builds ops that
    must be run explicitly, so filling a fresh variable slice by slice leaves
    it all zeros.

    Returns:
        float32 tensor of shape ``[tile_size_x, tile_size_y, channel]``.
    """
    def basis(pos, freq, size):
        # Orthonormal 1-D DCT-II basis: cos(pi*f*(p+0.5)/N)/sqrt(N), times sqrt(2) if f != 0.
        value = math.cos(math.pi * freq * (pos + 0.5) / size) / math.sqrt(size)
        return value if freq == 0 else value * math.sqrt(2)

    c_part = channel // len(mapper_x)
    freqs = list(zip(mapper_x, mapper_y))
    padding = [0.0] * (channel - c_part * len(freqs))
    values = []
    for t_x in range(tile_size_x):
        row = []
        for t_y in range(tile_size_y):
            pixel = []
            for u_x, v_y in freqs:
                coeff = basis(t_x, u_x, tile_size_x) * basis(t_y, v_y, tile_size_y)
                pixel.extend([coeff] * c_part)
            row.append(pixel + padding)
        values.append(row)
    return tf.constant(values, dtype=tf.float32)
5. BAM
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.slim as slim
# Keyword arguments for slim.batch_norm (presumably passed to BAM as its
# batch_norm_params argument, which forwards them as normalizer_params).
batch_norm_params = dict(
    decay=0.995,  # decay of the moving averages
    epsilon=0.001,  # added to the variance so it is never 0
    updates_collections=None,  # force in-place updates of the moving mean/variance
    # the moving averages are also placed in the trainable-variables collection
    variables_collections=[tf.GraphKeys.TRAINABLE_VARIABLES],
)
def BAM(inputs, batch_norm_params, reduction_ratio=16, dilation_value=4, reuse=None, scope='BAM'):
    """Bottleneck Attention Module (BAM) built with TF-Slim.

    Computes a channel-attention map ([N, 1, 1, C]) and a spatial-attention
    map ([N, H, W, 1]), adds them with broadcasting, applies a sigmoid and
    returns ``inputs + inputs * attention``.

    Args:
        inputs: 4-D channels-last feature map (the channel branch averages
            over axes 1 and 2).
        batch_norm_params: keyword arguments for ``slim.batch_norm``, applied
            to the last layer of each branch.
        reduction_ratio: channel reduction of the bottleneck
            (``C // reduction_ratio`` hidden channels).
        dilation_value: dilation rate of the two 3x3 spatial convolutions.
        reuse: forwarded to ``tf.variable_scope``.
        scope: variable scope name.

    Returns:
        Tensor with the same shape as ``inputs``.

    NOTE(review): every hidden layer here is linear. The first
    fully_connected passes activation_fn=None, and the inner arg_scope sets
    activation_fn=None for every conv2d. As a result, the stacked layers in
    each branch compose to a single linear map. The BAM paper describes the
    channel branch as an MLP with a hidden layer; confirm whether a
    non-linearity (e.g. BN + ReLU on hidden layers) was intended.
    Slim derives variable names from layer-creation order, so adding or
    reordering layers changes checkpoint names.
    """
    with tf.variable_scope(scope, reuse=reuse):
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            with slim.arg_scope([slim.conv2d], activation_fn=None):
                input_channel = inputs.get_shape().as_list()[-1]
                num_squeeze = input_channel // reduction_ratio
                # Channel attention
                gap = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)  # [N, 1, 1, C]
                channel = slim.fully_connected(gap, num_squeeze, activation_fn=None)
                # Batch norm on the output layer only, per batch_norm_params.
                channel = slim.fully_connected(channel, input_channel, activation_fn=None,
                                               normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params)
                # Spatial attention: 1x1 reduce, two dilated 3x3 convs, 1x1 to one map.
                spatial = slim.conv2d(inputs, num_squeeze, 1, padding='SAME')
                spatial = slim.repeat(spatial, 2, slim.conv2d, num_squeeze, 3, padding='SAME', rate=dilation_value)
                spatial = slim.conv2d(spatial, 1, 1, padding='SAME',
                                      normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params)  # [N, H, W, 1]
                # Combine the two attention branches ([N,1,1,C] + [N,H,W,1] broadcasts to [N,H,W,C]).
                combined = tf.nn.sigmoid(channel + spatial)
                # Residual form: the attention re-weights features that are added back to the input.
                output = inputs + inputs * combined
                return output
6. GlobalContext
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
def conv(x, out_channel, kernel_size, stride=1, dilation=1):
    """slim.conv2d with no activation, optional stride and dilation rate."""
    return slim.conv2d(
        x,
        out_channel,
        kernel_size,
        stride,
        rate=dilation,
        activation_fn=None,
    )
def global_avg_pool2D(x):
    """Average ``x`` over axes 1 and 2, keeping them as size-1 dimensions.

    For a channels-last [N, H, W, C] input this returns [N, 1, 1, C].
    ``reduce_mean`` does not need static H/W, unlike an avg_pool2d whose
    window is the full spatial size.
    """
    with tf.variable_scope(None, 'global_pool2D'):
        return tf.reduce_mean(x, axis=[1, 2], keepdims=True)
def global_context_module(x, squeeze_depth, fuse_method='add', attention_method='att', scope=None):
    """Global context (GC) block from GCNet.

    Pools a global context vector from ``x``: softmax-weighted attention
    pooling for 'att', or plain global average pooling for 'avg'. The vector
    is transformed by conv1x1 -> LayerNorm -> ReLU -> conv1x1 and then fused
    with ``x`` by broadcast addition ('add') or sigmoid gating ('mul').

    Args:
        x: channels-last tensor [N, H, W, C] with a static channel count.
        squeeze_depth: number of channels of the bottleneck conv.
        fuse_method: 'add' or 'mul'.
        attention_method: 'att' or 'avg'.
        scope: optional variable scope name (default scope name 'GCModule').

    Returns:
        Tensor with the same shape as ``x``.
    """
    assert fuse_method in ['add', 'mul']
    assert attention_method in ['att', 'avg']
    with tf.variable_scope(scope, "GCModule"):
        # Both pooling branches need C (the final conv restores C channels),
        # so read it before branching.
        c = x.get_shape().as_list()[-1]
        if attention_method == 'avg':
            context = global_avg_pool2D(x)  # [N, 1, 1, C]
        else:
            context_mask = conv(x, 1, 1)  # [N, H, W, 1]
            context_mask = tf.reshape(context_mask, shape=tf.convert_to_tensor([tf.shape(x)[0], -1, 1]))  # [N, H*W, 1]
            context_mask = tf.transpose(context_mask, perm=[0, 2, 1])  # [N, 1, H*W]
            context_mask = tf.nn.softmax(context_mask, axis=2)  # softmax over spatial positions
            input_x = tf.reshape(x, shape=tf.convert_to_tensor([tf.shape(x)[0], -1, c]))  # [N, H*W, C]
            context = tf.matmul(context_mask, input_x)  # [N, 1, H*W] x [N, H*W, C] = [N, 1, C]
            context = tf.expand_dims(context, axis=1)  # [N, 1, 1, C]
        context = conv(context, squeeze_depth, 1)
        context = slim.layer_norm(context)
        context = tf.nn.relu(context)
        context = conv(context, c, 1)  # [N, 1, 1, C]
        if fuse_method == 'mul':
            out = tf.nn.sigmoid(context) * x
        else:
            out = context + x
    return out
部分参考文献
[91] Wang Q, Wu B, Zhu P, et al. ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks[C]//2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2020.
[95] Woo S, Park J, Lee J Y, et al. CBAM: Convolutional block attention module[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 3-19.
[105] Hou Q, Zhou D, Feng J. Coordinate attention for efficient mobile network design[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 13713-13722.
[106] Cao Y, Xu J, Lin S, et al. GCNet: Non-local networks meet squeeze-excitation networks and beyond[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops. 2019: 0-0.
[107] Li X, Wang W, Hu X, et al. Selective kernel networks[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 510-519.