【附代码实现】Attention注意力模块的keras\tf实现(ECA、BAM、Coordinate、DualAttention、GlobalContext等)

前言

研究生阶段的一些工作涉及到了注意力方面的研究，所以复现了一些比较出名的注意力模块。这些都是我和朋友根据自己的理解复现的，主要用的是 keras（部分模块使用 tensorflow / slim），不保证复现的正确性，欢迎交流。

1. ECA


https://blog.csdn.net/qq_35054151/article/details/115434812

import math
from keras.layers import *
from keras.layers import Activation
from keras.layers import GlobalAveragePooling2D
import keras.backend as K
import tensorflow as tf
def eca_layer(inputs_tensor=None,num=None,gamma=2,b=1):
    """
    Efficient Channel Attention (ECA) block assembled from Keras layers.

    :param inputs_tensor: feature map, input_tensor.shape=[batchsize,h,w,channels]
    :param num: value appended (via ``str``) to the Conv1D / Activation layer names
    :param gamma: divisor used when deriving the 1-D convolution kernel size
    :param b: offset added to log2(channels) when deriving the kernel size
    :return: ``inputs_tensor`` multiplied by the per-channel sigmoid weights
    """
    n_channels = K.int_shape(inputs_tensor)[-1]

    # Kernel size grows with log2(channels); bump even values to the next odd one.
    raw_size = int(abs((math.log(n_channels, 2) + b) / gamma))
    kernel = raw_size + 1 if raw_size % 2 == 0 else raw_size

    suffix = str(num)

    # Squeeze: one average value per channel, laid out as a length-C sequence.
    weights = GlobalAveragePooling2D()(inputs_tensor)
    weights = Reshape((n_channels, 1))(weights)

    # Local cross-channel interaction followed by a sigmoid gate.
    # NOTE: the activation layer keeps its historical "relu" name although it is a sigmoid.
    weights = Conv1D(1, kernel_size=kernel, padding="same", name="eca_conv1_" + suffix)(weights)
    weights = Activation('sigmoid', name='eca_conv1_relu_' + suffix)(weights)

    # Back to a (1, 1, C) map so it can scale the input feature map.
    weights = Reshape((1, 1, n_channels))(weights)
    return multiply([inputs_tensor, weights])

2. Coordinate attention

import tensorflow as tf
from keras.layers import Lambda,Concatenate,Reshape,Conv2D,BatchNormalization,Activation,Multiply,Add

def coordinate(inputs,ratio=2, name="name"):
    """Coordinate-attention style block with a residual connection.

    Args:
        inputs: 4-D tensor with fully defined static sizes on axes 1, 2 and 3
            (they are converted with ``int``).
        ratio: channel reduction factor for the shared 1x1 convolution.
        name: prefix for the three Conv2D layer names (suffixes '1', '2', '3').

    Returns:
        ``inputs + inputs * gate_a * gate_b`` where the two sigmoid gates are
        derived from the feature map averaged along axis 1 and along axis 2.
    """
    size_ax1, size_ax2, channels = (int(d) for d in inputs.shape[1:])
    squeezed = max(int(channels // ratio), ratio)

    # Averaging over axis 1 leaves a strip of length size_ax2, and vice versa.
    strip_ax2 = Lambda(lambda t: tf.reduce_mean(t, axis=1))(inputs)
    strip_ax1 = Lambda(lambda t: tf.reduce_mean(t, axis=2))(inputs)

    # Both strips share one 1x1 conv + BN + ReLU, so stack them along one axis.
    joint = Concatenate(axis=1)([strip_ax2, strip_ax1])
    joint = Reshape((1, size_ax1 + size_ax2, channels))(joint)
    joint = Conv2D(squeezed, 1, name=name + '1')(joint)
    joint = BatchNormalization()(joint)
    joint = Activation('relu')(joint)

    # Undo the stacking: the first size_ax2 entries belong to the axis-2 strip.
    gate_ax2, gate_ax1 = Lambda(
        lambda t: tf.split(t, [size_ax2, size_ax1], axis=2))(joint)
    # Move the axis-1 strip onto axis 1 so that it broadcasts across axis 2.
    gate_ax1 = Reshape((size_ax1, 1, squeezed))(gate_ax1)

    gate_ax2 = Conv2D(channels, 1, activation='sigmoid', name=name + "2")(gate_ax2)
    gate_ax1 = Conv2D(channels, 1, activation='sigmoid', name=name + "3")(gate_ax1)

    attended = Multiply()([inputs, gate_ax2, gate_ax1])
    return Add()([inputs, attended])

3. Dual attention


import keras
from keras.layers import Activation, Conv2D
import keras.backend as K
import tensorflow as tf
from keras.layers import Layer


#  位置注意
class PAM(Layer):
    """Position (spatial) attention layer in the style of DANet.

    Every spatial location attends to every other location. The attended
    features are scaled by a learnable scalar ``beta`` and added back to the
    input, so the output has the same shape as the input.

    Args:
        beta_initializer / beta_regularizer / beta_constraint: settings for
            the scalar ``beta`` weight (zero-initialised by default, so the
            layer starts out as an identity mapping).
        kernal_initializer / kernal_regularizer / kernal_constraint: settings
            shared by the three projection kernels (the ``kernal`` spelling is
            kept for backward compatibility with existing callers).
        **kwargs: forwarded to ``keras.layers.Layer``.

    Notes:
        * ``call`` needs static height, width and channel sizes.
        * The query/key projections have ``filters // 8`` columns, which is 0
          for inputs with fewer than 8 channels.
    """

    def __init__(self,
                 beta_initializer=keras.initializers.Zeros(),
                 beta_regularizer=None,
                 beta_constraint=None,
                 kernal_initializer='he_normal',
                 kernal_regularizer=None,
                 kernal_constraint=None,
                 **kwargs):
        super(PAM, self).__init__(**kwargs)

        # Normalise identifiers (strings / config dicts / objects) to objects
        # so that get_config() can serialise them uniformly.
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.beta_constraint = keras.constraints.get(beta_constraint)

        self.kernal_initializer = keras.initializers.get(kernal_initializer)
        self.kernal_regularizer = keras.regularizers.get(kernal_regularizer)
        self.kernal_constraint = keras.constraints.get(kernal_constraint)

    def build(self, input_shape):
        """Create ``beta`` and the three channel-projection kernels."""
        filters = input_shape[-1]

        self.beta = self.add_weight(shape=(1,),
                                    initializer=self.beta_initializer,
                                    name='beta',
                                    regularizer=self.beta_regularizer,
                                    constraint=self.beta_constraint,
                                    trainable=True)

        # kernel_b / kernel_c project to a reduced width, kernel_d keeps the
        # full channel count.
        self.kernel_b = self._add_kernel('kernel_b', (filters, filters // 8))
        self.kernel_c = self._add_kernel('kernel_c', (filters, filters // 8))
        self.kernel_d = self._add_kernel('kernel_d', (filters, filters))

        self.built = True

    def _add_kernel(self, name, shape):
        """Add one trainable projection kernel with the shared kernel settings."""
        return self.add_weight(shape=shape,
                               initializer=self.kernal_initializer,
                               name=name,
                               regularizer=self.kernal_regularizer,
                               constraint=self.kernal_constraint,
                               trainable=True)

    def compute_output_shape(self, input_shape):
        """The layer is shape-preserving."""
        return input_shape

    def call(self, inputs):
        """Apply position attention and the ``beta``-scaled residual."""
        _, h, w, filters = inputs.get_shape().as_list()

        # Per-pixel linear projections along the channel axis.
        b = K.dot(inputs, self.kernel_b)
        c = K.dot(inputs, self.kernel_c)
        d = K.dot(inputs, self.kernel_d)

        # Flatten the spatial grid: (batch, h*w, channels).
        vec_b = K.reshape(b, (-1, h * w, filters // 8))
        vec_cT = K.permute_dimensions(
            K.reshape(c, (-1, h * w, filters // 8)), (0, 2, 1))

        # (batch, h*w, h*w) affinities, normalised along the last axis.
        attention = K.softmax(K.batch_dot(vec_b, vec_cT))

        vec_d = K.reshape(d, (-1, h * w, filters))
        attended = K.batch_dot(attention, vec_d)
        attended = K.reshape(attended, (-1, h, w, filters))

        return self.beta * attended + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/reloaded."""
        config = {
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
            'kernal_initializer': keras.initializers.serialize(self.kernal_initializer),
            'kernal_regularizer': keras.regularizers.serialize(self.kernal_regularizer),
            'kernal_constraint': keras.constraints.serialize(self.kernal_constraint),
        }
        base_config = super(PAM, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

#  通道注意
class CAM(Layer):
    """Channel attention layer in the style of DANet.

    A channel-by-channel affinity matrix is computed from the flattened
    feature map, passed through a softmax and used to re-mix the channels.
    The result is scaled by a learnable scalar ``gamma`` and added back to
    the input, so the output has the same shape as the input.

    Args:
        gamma_initializer / gamma_regularizer / gamma_constraint: settings for
            the scalar ``gamma`` weight (zero-initialised by default, so the
            layer starts out as an identity mapping).
        **kwargs: forwarded to ``keras.layers.Layer``.

    Notes:
        ``call`` needs static height, width and channel sizes.
    """

    def __init__(self,
                 gamma_initializer=keras.initializers.Zeros(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(CAM, self).__init__(**kwargs)
        # Normalise identifiers to objects so get_config() can serialise them.
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create the scalar ``gamma`` weight."""
        self.gamma = self.add_weight(shape=(1,),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)
        self.built = True

    def compute_output_shape(self, input_shape):
        """The layer is shape-preserving."""
        return input_shape

    def call(self, inputs):
        """Apply channel attention and the ``gamma``-scaled residual."""
        _, h, w, filters = inputs.get_shape().as_list()

        # (batch, h*w, C) and its transpose (batch, C, h*w).
        vec_a = K.reshape(inputs, (-1, h * w, filters))
        vec_aT = K.permute_dimensions(vec_a, (0, 2, 1))

        # (batch, C, C) channel affinities, softmax along the last axis.
        # NOTE(review): the DANet paper appears to normalise over the source
        # channel for each output channel; here each source row is normalised
        # over the output channels. Kept as-is to preserve existing behaviour
        # - confirm against the reference implementation.
        attention = K.softmax(K.batch_dot(vec_aT, vec_a))

        attended = K.batch_dot(vec_a, attention)
        attended = K.reshape(attended, (-1, h, w, filters))

        return self.gamma * attended + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/reloaded."""
        config = {
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint),
        }
        base_config = super(CAM, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


#  使用方法
# pam = PAM()(reduce_conv5_3)
# cam = CAM()(reduce_conv5_3)
# feature_sum = add([pam, cam])

4. FrequencyChannelAttention



import math
import tensorflow as tf
import math


def get_freq_indices(method):
    """Return the (x, y) frequency index lists for a selection method.

    ``method`` is a three-letter family ('top', 'low' or 'bot') followed by
    the number of frequencies to keep (1, 2, 4, 8, 16 or 32), e.g. 'top16'.
    The first N entries of that family's index tables are returned as two
    lists ``(mapper_x, mapper_y)``. The tables hold indices in the range 0-6.

    Raises:
        AssertionError: if ``method`` is not one of the supported names.
    """
    assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32',
                      'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32',
                      'low1', 'low2', 'low4', 'low8', 'low16', 'low32']
    count = int(method[3:])

    # Checked in this order to mirror the original top / low / bot branches.
    tables = (
        ('top',
         [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2,
          4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 6, 1],
         [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2,
          6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 5, 3]),
        ('low',
         [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0,
          1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4],
         [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5,
          4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3]),
        ('bot',
         [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5,
          6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 3, 6],
         [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1,
          4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 3, 3]),
    )

    for family, xs, ys in tables:
        if family in method:
            return xs[:count], ys[:count]
    raise NotImplementedError


#  注意力层
#  Attention layer
def MultiSpectralAttentionLayer(x, channel, dct_h, dct_w, reduction=16, freq_sel_method='top2'):
    """Multi-spectral (frequency) channel attention.

    Args:
        x: feature map; it is multiplied element-wise by a DCT filter of
            shape (dct_h, dct_w, channel) and by a (batch, 1, 1, channel)
            gate, so it is expected to be (batch, dct_h, dct_w, channel).
        channel: number of channels of ``x``; must be divisible by the number
            of selected frequencies.
        dct_h, dct_w: spatial size of the DCT filter; the frequency indices
            are scaled by ``dct_h // 7`` and ``dct_w // 7``.
        reduction: bottleneck ratio of the two dense layers.
        freq_sel_method: selection name understood by ``get_freq_indices``.

    Returns:
        ``x`` scaled per channel by a sigmoid gate.
    """
    print("------MultiSpectralAttentionLayer----start")
    mapper_x, mapper_y = get_freq_indices(freq_sel_method)
    # The index tables are defined on a 7x7 grid; rescale to the filter size.
    mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]
    mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]

    # Squeeze: DCT-weighted spatial sum -> (batch, channel).
    y = MultiSpectralDCTLayer(x, dct_h, dct_w, mapper_x, mapper_y, channel)

    # Excite: bottleneck MLP followed by a sigmoid gate.
    y = tf.layers.dense(y, channel // reduction, activation=tf.nn.relu)
    y = tf.layers.dense(y, channel)
    y = tf.math.sigmoid(y)

    # Use -1 for the batch axis so an unknown batch size works, and rely on
    # broadcasting instead of tiling the gate over the spatial grid.
    y = tf.reshape(y, [-1, 1, 1, channel])
    print("------MultiSpectralAttentionLayer----end")
    return x * y


def MultiSpectralDCTLayer(x, height, width, mapper_x, mapper_y, channel):
    """Weight ``x`` with the DCT filter and sum over axes 1 and 2.

    Args:
        x: feature map that can be multiplied element-wise with the
            (height, width, channel) filter returned by ``get_dct_filter``.
        height, width: spatial size passed to ``get_dct_filter``.
        mapper_x, mapper_y: frequency indices, one pair per channel group.
        channel: number of channels; must be a multiple of ``len(mapper_x)``.

    Returns:
        The weighted feature map reduced over axes 1 and 2.
    """
    print("------MutilSpectralDCTLaer----start")
    assert channel % len(mapper_x) == 0
    dct_weight = get_dct_filter(height, width, mapper_x, mapper_y, channel)
    print(height)
    print(width)
    pooled = tf.reduce_sum(x * dct_weight, [1, 2])
    print("------MutilSpectralDCTLaer----end")
    return pooled


def build_filter(pos, freq, POS):
    """Evaluate one 1-D DCT basis function.

    Computes ``cos(pi * freq * (pos + 0.5) / POS) / sqrt(POS)`` and multiplies
    it by ``sqrt(2)`` for every non-zero frequency.

    Args:
        pos: sample position (number or tensor).
        freq: frequency index (number or tensor).
        POS: total number of positions (number or tensor).

    Returns:
        A float32 tensor holding the basis value.
    """
    # The previous version assigned the cast of `pos` to `POS`, which replaced
    # the length by the position (and divided by zero at pos == 0).
    pos = tf.cast(pos, tf.float32)
    freq = tf.cast(freq, tf.float32)
    POS = tf.cast(POS, tf.float32)

    result = tf.math.cos(math.pi * freq * (pos + 0.5) / POS) / tf.math.sqrt(POS)

    # A Python `if freq == 0` on a tensor is not a reliable value comparison
    # in graph mode, so select the scale arithmetically instead:
    # 1 for freq == 0, sqrt(2) otherwise.
    non_dc = tf.cast(tf.not_equal(freq, 0.0), tf.float32)
    scale = 1.0 + (math.sqrt(2.0) - 1.0) * non_dc
    return result * scale


def get_dct_filter(tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
    """Build the fixed 2-D DCT filter bank used for frequency pooling.

    The channels are split into ``len(mapper_x)`` equal groups; group ``i``
    is filled with the 2-D DCT basis of frequency
    ``(mapper_x[i], mapper_y[i])`` sampled on a
    ``tile_size_x`` x ``tile_size_y`` grid.

    Args:
        tile_size_x, tile_size_y: spatial size of the filter.
        mapper_x, mapper_y: frequency indices, one pair per channel group.
        channel: total number of channels.

    Returns:
        A constant float32 tensor of shape (tile_size_x, tile_size_y, channel).
    """
    # Local import: NumPy is only needed here and the surrounding snippet
    # does not import it.
    import numpy as np

    print("------get_dct_filter----statr")

    def _basis(freq, size):
        # 1-D DCT basis over all positions at once; non-zero frequencies
        # carry an extra sqrt(2) factor.
        positions = np.arange(size, dtype=np.float64)
        values = np.cos(np.pi * freq * (positions + 0.5) / size) / np.sqrt(size)
        return values if freq == 0 else values * np.sqrt(2.0)

    # The filter is computed eagerly in NumPy. Slice-assigning into a
    # tf.Variable only creates assign ops, which are not executed in
    # graph mode, so the filter would stay at zero there.
    dct_filter = np.zeros((tile_size_x, tile_size_y, channel), dtype=np.float32)
    c_part = channel // len(mapper_x)

    for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
        # Separable 2-D basis = outer product of the two 1-D bases.
        plane = np.outer(_basis(u_x, tile_size_x), _basis(v_y, tile_size_y))
        dct_filter[:, :, i * c_part: (i + 1) * c_part] = plane[:, :, np.newaxis]

    print("------get_dct_filter----end")
    return tf.constant(dct_filter)

5. BAM

# -*- coding: utf-8 -*-


import tensorflow as tf
import tensorflow.contrib.slim as slim

# Module-level batch-norm settings dictionary.
# NOTE(review): presumably meant to be passed as the `batch_norm_params`
# argument of BAM(), which forwards it as slim `normalizer_params` - confirm
# against the caller.
batch_norm_params = {
                    # Decay for moving averages
                    'decay': 0.995,
                    # epsilon to prevent 0 in variance
                    'epsilon': 0.001,
                    # force in-place updates of mean and variances estimates
                    'updates_collections': None,
                    # moving averages ends up in the trainable variables collection
                    'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES]}

def BAM(inputs, batch_norm_params, reduction_ratio=16, dilation_value=4, reuse=None, scope='BAM'):
    """Bottleneck-attention style block built with TF-Slim.

    Args:
        inputs: 4-D feature map; the channel count is read from the last
            static dimension.
        batch_norm_params: dict forwarded as ``normalizer_params`` to the
            batch-normalised layers.
        reduction_ratio: channel reduction factor of both branches.
        dilation_value: dilation rate of the two 3x3 convolutions.
        reuse: forwarded to ``tf.variable_scope``.
        scope: variable scope name.

    Returns:
        ``inputs + inputs * sigmoid(channel_branch + spatial_branch)``.
    """
    # Shared keyword arguments for the two batch-normalised output layers.
    bn_kwargs = {'normalizer_fn': slim.batch_norm,
                 'normalizer_params': batch_norm_params}

    with tf.variable_scope(scope, reuse=reuse):
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=slim.l2_regularizer(0.0005)):
            with slim.arg_scope([slim.conv2d], activation_fn=None):
                n_channels = inputs.get_shape().as_list()[-1]
                squeezed = n_channels // reduction_ratio

                # Channel branch: global average pool, then a two-layer bottleneck.
                channel_att = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
                channel_att = slim.fully_connected(channel_att, squeezed, activation_fn=None)
                channel_att = slim.fully_connected(channel_att, n_channels, activation_fn=None,
                                                   **bn_kwargs)

                # Spatial branch: 1x1 reduce, two dilated 3x3 convs, 1x1 to one map.
                spatial_att = slim.conv2d(inputs, squeezed, 1, padding='SAME')
                spatial_att = slim.repeat(spatial_att, 2, slim.conv2d, squeezed, 3,
                                          padding='SAME', rate=dilation_value)
                spatial_att = slim.conv2d(spatial_att, 1, 1, padding='SAME', **bn_kwargs)

                # The two branches are added (broadcasting) and squashed into a gate.
                gate = tf.nn.sigmoid(channel_att + spatial_att)

                # Residual form: the gate refines rather than replaces the input.
                return inputs + inputs * gate

6. GlobalContext

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

def conv(x, out_channel, kernel_size, stride=1, dilation=1):
    """Thin wrapper around ``slim.conv2d`` with no activation function.

    ``dilation`` is passed to slim as ``rate``; every other slim default is
    left untouched.
    """
    return slim.conv2d(x, out_channel, kernel_size, stride,
                       rate=dilation, activation_fn=None)

def global_avg_pool2D(x):
    """Average-pool ``x`` with a window covering its full static height and width.

    Requires the static sizes of axes 1 and 2 to be known, because they are
    used as the pooling window.
    """
    with tf.variable_scope(None, 'global_pool2D'):
        # `as_list` must be called; unpacking the bound method itself raised
        # a TypeError.
        _, h, w, _ = x.get_shape().as_list()
        x = slim.avg_pool2d(x, (h, w), stride=1)
    return x

def global_context_module(x,squeeze_depth,fuse_method='add',attention_method='att',scope=None):
    """Global-context (GCNet-style) block.

    Args:
        x: 4-D feature map whose static channel count is known.
        squeeze_depth: bottleneck width of the transform branch.
        fuse_method: 'add' adds the context to ``x``; 'mul' multiplies ``x``
            by the sigmoid of the context.
        attention_method: 'att' pools with a learned softmax spatial mask;
            'avg' uses plain global average pooling.
        scope: optional variable scope name (default name "GCModule").

    Returns:
        The fused feature map.
    """
    assert fuse_method in ['add','mul']
    assert attention_method in ['att','avg']

    with tf.variable_scope(scope,"GCModule"):
        # The channel count is needed by both pooling modes for the final
        # 1x1 conv; it used to be bound only in the 'att' branch, so the
        # 'avg' mode failed with a NameError.
        c = x.get_shape().as_list()[-1]

        if attention_method == 'avg':
            context = global_avg_pool2D(x)  # [N,1,1,C]
        else:
            batch = tf.shape(x)[0]
            context_mask = conv(x, 1, 1)  # [N, H, W, 1]
            context_mask = tf.reshape(
                context_mask, shape=tf.convert_to_tensor([batch, -1, 1]))  # [N, H*W, 1]
            context_mask = tf.transpose(context_mask, perm=[0, 2, 1])  # [N, 1, H*W]
            context_mask = tf.nn.softmax(context_mask, axis=2)  # [N, 1, H*W]

            input_x = tf.reshape(
                x, shape=tf.convert_to_tensor([batch, -1, c]))  # [N, H*W, C]

            # [N, 1, H*W] x [N, H*W, C] = [N, 1, C]
            context = tf.matmul(context_mask, input_x)
            context = tf.expand_dims(context, axis=1)  # [N,1,1,C]

        # Transform branch: squeeze -> layer norm -> ReLU -> restore channels.
        context = conv(context, squeeze_depth, 1)
        context = slim.layer_norm(context)
        context = tf.nn.relu(context)
        context = conv(context, c, 1)  # [N,1,1,C]

        if fuse_method == 'mul':
            context = tf.nn.sigmoid(context)
            out = context * x
        else:
            out = context + x

        return out

部分参考文献

[91]Wang Q ,  Wu B ,  Zhu P , et al. ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks[C]// 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2020.
[95]Woo S, Park J, Lee J Y, et al. Cbam: Convolutional block attention module[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 3-19.
[105] Hou Q, Zhou D, Feng J. Coordinate attention for efficient mobile network design[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 13713-13722.
[106] Cao Y, Xu J, Lin S, et al. Gcnet: Non-local networks meet squeeze-excitation networks and beyond[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops. 2019: 0-0.
[107] Li X, Wang W, Hu X, et al. Selective kernel networks[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 510-519.

  • 13
    点赞
  • 69
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值