基于Keras的Channel-Spatial Attention Layers的实现

基于Keras的Channel-Spatial Attention Layers的实现

描述
注意力机制(Attention)已被应用于医学图像分割领域,用于增强神经网络对空间特征和通道特征的关注,从而提高分割精度。本文参照参考文献,实现了一种将 Spatial Attention 与 Channel Attention 融合使用的注意力层。
(此处原有一张插图,图片未能保留。)

代码展示
1.Spatial Attention Layer

class SpatialAttention(Layer):
    """Spatial attention gate that reweights a low-level feature map using a high-level one.

    The gate is built once, in ``__init__``, as an inner functional ``Model``
    with fixed input shapes:

    * ``Fx_``: the low-level input through a stride-2 convolution (stride 2 on
      the first two spatial axes) with ``feature_dim_high`` filters;
    * ``Fy_``: the high-level input through a 1x1 convolution with
      ``feature_dim_high`` filters;
    * ``Fx_ + Fy_`` -> ReLU conv (3x3, or 3x3x1 in 3D) -> 1x1 sigmoid conv with
      ``outChan`` filters -> 2x upsampling on the first two spatial axes ->
      elementwise product with the low-level input.

    Because ``Fx_`` and ``Fy_`` are added, the high-level spatial size must match
    the low-level size after the stride-2 downsampling. Because the upsampled map
    is multiplied with the low-level input, the low-level size must survive the
    down/upsample round trip (e.g. be even along the first two spatial axes).

    Args:
        outChan: Channel count of the low-level input and of the output.
        feature_size_high: Spatial size of the high-level input
            (2 entries for ``'2D'``, 3 for ``'3D'``).
        feature_size_low: Spatial size of the low-level input.
        feature_dim_high: Channel count of the high-level input; also the width
            of the intermediate convolutions.
        mode: ``'2D'`` or ``'3D'``. In ``'3D'`` the last spatial axis is never
            strided or upsampled.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to the base class.

    Raises:
        ValueError: If ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_high, feature_size_low, feature_dim_high, mode, **kwargs):
        # Configs written by earlier versions of get_config stored the inner
        # Model under this key. It is rebuilt below, so drop it rather than
        # forwarding it to Layer.__init__, which validates its kwargs.
        kwargs.pop('SpatialAtten', None)
        # Forward name/trainable/dtype so they survive get_config/from_config.
        super(SpatialAttention, self).__init__(**kwargs)
        self.outChan = outChan
        self.feature_size_high = feature_size_high
        self.feature_size_low = feature_size_low
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        if mode == '2D':
            ndim, Conv, UpSampling = 2, Conv2D, UpSampling2D
            down, kernel, up = 2, 3, (2, 2)
        elif mode == '3D':
            # The third spatial axis is kept at full resolution throughout.
            ndim, Conv, UpSampling = 3, Conv3D, UpSampling3D
            down, kernel, up = (2, 2, 1), (3, 3, 1), (2, 2, 1)
        else:
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))

        shape_high = tuple(feature_size_high[i] for i in range(ndim)) + (feature_dim_high,)
        shape_low = tuple(feature_size_low[i] for i in range(ndim)) + (outChan,)
        inputs_high = Input(shape=shape_high)
        inputs_low = Input(shape=shape_low)

        # Bring the low-level features down to the high-level resolution.
        Fx_ = Conv(feature_dim_high, down, strides=down, padding='same')(inputs_low)
        Fy_ = Conv(feature_dim_high, 1, padding='same')(inputs_high)
        M_Spatial = Conv(feature_dim_high, kernel, activation='relu', padding='same')(Fx_ + Fy_)
        # One sigmoid gate value per position and output channel.
        M_Spatial = Conv(outChan, 1, activation='sigmoid', padding='same')(M_Spatial)
        # Back to the low-level resolution before gating the low-level input.
        M_Spatial = UpSampling(size=up)(M_Spatial)
        M_Spatial = Multiply()([M_Spatial, inputs_low])
        self.SpatialAtten = Model(inputs=[inputs_high, inputs_low], outputs=M_Spatial)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The inner ``Model`` is deliberately not included: it is not a plain
        serializable value, and ``__init__`` recreates it from these arguments.
        """
        config = super().get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_high': self.feature_size_high,
                'feature_size_low': self.feature_size_low,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply the spatial gate.

        Args:
            inputs: A pair ``[Fy, Fx]``: the high-level feature map first, then
                the low-level feature map, with the shapes given to ``__init__``.

        Returns:
            ``Fx`` multiplied elementwise by the upsampled sigmoid attention map.
        """
        Fy, Fx = inputs
        return self.SpatialAtten([Fy, Fx])

2.Channel Attention Layer

class ChannelAttention(Layer):
    """Channel attention gate that reweights a low-level feature map using a high-level one.

    The gate is built once, in ``__init__``, as an inner functional ``Model``
    with fixed input shapes:

    * both inputs go through a 1x1 convolution with ``feature_dim_high`` filters;
    * each result is average-pooled (2x2, or 2x2x1 in 3D) and then averaged over
      all spatial axes, giving one ``(batch, feature_dim_high)`` vector per input;
    * the two vectors are summed -> ReLU ``Dense(feature_dim_high)`` -> sigmoid
      ``Dense(outChan)``;
    * the resulting per-channel weights are tiled to the low-level spatial size
      and multiplied elementwise with the low-level input.

    Args:
        outChan: Channel count of the low-level input and of the output.
        feature_size_low: Spatial size of the low-level input
            (2 entries for ``'2D'``, 3 for ``'3D'``).
        feature_size_high: Spatial size of the high-level input.
        feature_dim_high: Channel count of the high-level input; also the width
            of the intermediate convolutions and of the hidden ``Dense`` layer.
        mode: ``'2D'`` or ``'3D'``.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to the base class.

    Raises:
        ValueError: If ``mode`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, outChan, feature_size_low, feature_size_high, feature_dim_high, mode, **kwargs):
        # Configs written by earlier versions of get_config stored the inner
        # Model under this key. It is rebuilt below, so drop it rather than
        # forwarding it to Layer.__init__, which validates its kwargs.
        kwargs.pop('ChanAtten', None)
        # Forward name/trainable/dtype so they survive get_config/from_config.
        super(ChannelAttention, self).__init__(**kwargs)
        self.outChan = outChan
        self.feature_size_high = feature_size_high
        self.feature_size_low = feature_size_low
        self.feature_dim_high = feature_dim_high
        self.mode = mode

        if mode == '2D':
            ndim, Conv, AvgPool, pool_size = 2, Conv2D, AveragePooling2D, (2, 2)
        elif mode == '3D':
            ndim, Conv, AvgPool, pool_size = 3, Conv3D, AveragePooling3D, (2, 2, 1)
        else:
            raise ValueError("mode must be '2D' or '3D', got %r" % (mode,))
        spatial_axes = list(range(1, ndim + 1))

        shape_high = tuple(feature_size_high[i] for i in range(ndim)) + (feature_dim_high,)
        shape_low = tuple(feature_size_low[i] for i in range(ndim)) + (outChan,)
        inputs_high = Input(shape=shape_high)
        inputs_low = Input(shape=shape_low)

        Fx_ = Conv(feature_dim_high, 1, padding='same')(inputs_low)
        Fy_ = Conv(feature_dim_high, 1, padding='same')(inputs_high)
        # Average-pool, then take the global mean over every spatial axis:
        # one descriptor of shape (batch, feature_dim_high) per input.
        Zx = K.mean(AvgPool(pool_size=pool_size)(Fx_), axis=spatial_axes)
        Zy = K.mean(AvgPool(pool_size=pool_size)(Fy_), axis=spatial_axes)
        FC1 = Dense(feature_dim_high, activation='relu')(Zx + Zy)
        M_chan = Dense(outChan, activation='sigmoid')(FC1)  # (batch, outChan)
        # (batch, outChan) -> (batch, 1, ..., 1, outChan)
        for axis in spatial_axes:
            M_chan = K.expand_dims(M_chan, axis=axis)
        # Tile the per-channel weights to the full low-level spatial size.
        for axis in spatial_axes:
            M_chan = K.repeat_elements(M_chan, feature_size_low[axis - 1], axis=axis)
        M_chan = Multiply()([M_chan, inputs_low])
        self.ChanAtten = Model(inputs=[inputs_high, inputs_low], outputs=M_chan)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The inner ``Model`` is deliberately not included: it is not a plain
        serializable value, and ``__init__`` recreates it from these arguments.
        """
        config = super().get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_high': self.feature_size_high,
                'feature_size_low': self.feature_size_low,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply the channel gate.

        Args:
            inputs: A pair ``[Fy, Fx]``: the high-level feature map first, then
                the low-level feature map, with the shapes given to ``__init__``.

        Returns:
            ``Fx`` with every channel scaled by its sigmoid attention weight.
        """
        Fy, Fx = inputs
        return self.ChanAtten([Fy, Fx])

3. Channel-Spatial Attention Layer

class ChannelSpatialAttention(Layer):
    """Channel attention followed by spatial attention on a low-level feature map.

    ``call`` first gates the low-level input with a ``ChannelAttention`` layer.
    It then gates that result with a ``SpatialAttention`` layer. Both layers use
    the same high-level input and are built from the same constructor arguments.

    Args:
        outChan: Channel count of the low-level input and of the output.
        feature_size_low: Spatial size of the low-level input
            (2 entries for ``'2D'``, 3 for ``'3D'``).
        feature_size_high: Spatial size of the high-level input.
        feature_dim_high: Channel count of the high-level input.
        mode: ``'2D'`` or ``'3D'``; passed through to both sublayers.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to the base class.
    """

    def __init__(self, outChan, feature_size_low, feature_size_high, feature_dim_high, mode, **kwargs):
        # Configs written by earlier versions of get_config stored the sublayer
        # objects under these keys. They are rebuilt below, so drop them rather
        # than forwarding them to Layer.__init__, which validates its kwargs.
        kwargs.pop('ChannelAttention', None)
        kwargs.pop('SpatialAttention', None)
        # Forward name/trainable/dtype so they survive get_config/from_config.
        super(ChannelSpatialAttention, self).__init__(**kwargs)
        self.outChan = outChan
        self.feature_size_low = feature_size_low
        self.feature_size_high = feature_size_high
        self.feature_dim_high = feature_dim_high
        self.mode = mode
        self.ChannelAttention = ChannelAttention(outChan=outChan, feature_size_low=feature_size_low, feature_size_high=feature_size_high, feature_dim_high=feature_dim_high, mode=mode)
        self.SpatialAttention = SpatialAttention(outChan=outChan, feature_size_low=feature_size_low, feature_size_high=feature_size_high, feature_dim_high=feature_dim_high, mode=mode)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        The sublayers are deliberately not included: ``__init__`` recreates
        them from these arguments.
        """
        config = super().get_config().copy()
        config.update(
            {
                'outChan': self.outChan,
                'feature_size_low': self.feature_size_low,
                'feature_size_high': self.feature_size_high,
                'feature_dim_high': self.feature_dim_high,
                'mode': self.mode,
            }
        )
        return config

    def call(self, inputs):
        """Apply channel attention, then spatial attention.

        Args:
            inputs: A pair ``[Fy, Fx]``: the high-level feature map first, then
                the low-level feature map.

        Returns:
            The spatially gated version of the channel-gated ``Fx``.
        """
        Fy, Fx = inputs
        M_Chan = self.ChannelAttention([Fy, Fx])
        return self.SpatialAttention([Fy, M_Chan])

References

https://doi.org/10.1016/j.knosys.2021.106754

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值