PSA 极化自注意力(Polarized Self-Attention)- TensorFlow 2 代码

 论文中有两种结构:论文结构图中上方的为并联结构 PSA_p(两个分支的输出相加),下方的为串联结构 PSA_s(两个分支依次执行)。

import tensorflow as tf

def kaiming_init(module, distribution='normal'):
    """Configure He (Kaiming) initialization on a layer object.

    The function only assigns initializer attributes on ``module``; it does
    not touch any weights that already exist.
    NOTE(review): Keras layers presumably read these attributes when they are
    built, so this should be called before the layer is first used - confirm.

    :param module: layer object whose ``kernel_initializer`` (and, when it has
        a bias, ``bias_initializer``) attribute is overwritten
    :param distribution: 'normal' (default) for ``he_normal`` or 'uniform'
        for ``he_uniform``
    :raises AssertionError: if ``distribution`` is neither value
    """
    assert distribution in ['uniform', 'normal'], \
        "distribution must be 'uniform' or 'normal'"
    if distribution == 'uniform':
        module.kernel_initializer = tf.keras.initializers.he_uniform()
    else:
        module.kernel_initializer = tf.keras.initializers.he_normal()

    # Fix: the constant initializer belongs to the bias. The previous code
    # assigned it to ``kernel_initializer``, which replaced the He initializer
    # chosen above with a constant one.
    # ``use_bias`` is checked as well because an unbuilt layer may not expose
    # a ``bias`` attribute yet (assumption about Keras layers - TODO confirm).
    has_bias = (getattr(module, 'use_bias', False)
                or getattr(module, 'bias', None) is not None)
    if has_bias:
        module.bias_initializer = tf.keras.initializers.constant()


class PSA_p(tf.keras.Model):
    """Polarized self-attention block, parallel layout.

    A spatial-attention branch and a channel-attention branch are both applied
    to the same input, and their outputs are added together.
    """

    def __init__(self, planes, data_format='channels_last'):
        """
        :param planes: number of channels of the input feature map
        :param data_format: layout of the input, 'channels_last' (default) or
            'channels_first'
        """
        super(PSA_p, self).__init__()
        self.data_format = data_format
        self.planes = planes
        self.out_planes = planes // 2

        # Channel-attention branch ("left").
        self.conv_q_left = tf.keras.layers.Conv2D(filters=1, kernel_size=1, strides=1, padding='valid', use_bias=False)
        self.conv_v_left = tf.keras.layers.Conv2D(filters=self.out_planes, kernel_size=1, strides=1, padding='valid',
                                                  use_bias=False)
        self.conv_up_left = tf.keras.layers.Conv2D(filters=self.planes, kernel_size=1, strides=1, padding='valid',
                                                   use_bias=False)
        # Fix: this softmax is applied to a (B, H*W, 1) tensor and has to
        # normalize over the H*W positions (axis 1). The default axis (-1) has
        # size 1 there, so every weight came out as exactly 1.0.
        self.softmax_left = tf.keras.layers.Softmax(axis=1)
        self.sigmoid_left = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        # Spatial-attention branch ("right").
        self.conv_q_right = tf.keras.layers.Conv2D(filters=self.out_planes, kernel_size=1, strides=1, padding='valid', use_bias=False)
        self.conv_v_right = tf.keras.layers.Conv2D(filters=self.out_planes, kernel_size=1, strides=1, padding='valid', use_bias=False)
        self.global_pool = tf.keras.layers.GlobalAvgPool2D(keepdims=True)
        # Applied to a (B, 1, H*W) tensor, so the default last axis is the
        # H*W axis here.
        self.softmax_right = tf.keras.layers.Activation(activation=tf.keras.activations.softmax)
        self.sigmoid_right = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        self.reset_parameters()

    def reset_parameters(self):
        """Set He initialization on the four query/value convolutions."""
        kaiming_init(self.conv_v_right)
        kaiming_init(self.conv_q_right)
        kaiming_init(self.conv_v_left)
        kaiming_init(self.conv_q_left)

    def channel_pool(self, inputs):
        """Channel-attention branch: scale each channel of ``inputs``.

        :param inputs: 4-D feature map laid out according to ``self.data_format``
        :return: ``inputs`` multiplied by a per-channel sigmoid mask, in the
            same layout as ``inputs``
        """
        if self.data_format == 'channels_first':
            inputs = tf.transpose(inputs, perm=[0, 2, 3, 1])  # (B, C, H, W) -> (B, H, W, C)

        #   (B, H, W, IC) -> (B, H, W, IC/2)
        input_x = self.conv_v_left(inputs)
        # Index the dynamic shape instead of unpacking it: iterating over a
        # symbolic tensor is not allowed in graph mode.
        dyn_shape = tf.shape(input_x)
        batch = dyn_shape[0]
        positions = dyn_shape[1] * dyn_shape[2]
        #   (B, H, W, C) -> (B, H*W, C)
        # The channel count is passed as a Python int so it stays statically
        # known for conv_up_left further down.
        input_x = tf.reshape(input_x, shape=(batch, positions, self.out_planes))

        #   (B, H, W, IC) -> (B, H, W, 1)
        context_mask = self.conv_q_left(inputs)
        #   (B, H, W, 1) -> (B, H*W, 1)
        context_mask = tf.reshape(context_mask, shape=(batch, positions, 1))
        #   (B, H*W, 1) -> (B, H*W, 1), normalized over the H*W axis
        context_mask = self.softmax_left(context_mask)
        #   (B, C, H*W) matmul (B, H*W, 1) -> (B, C, 1)
        context = tf.matmul(a=tf.transpose(input_x, perm=[0, 2, 1]), b=context_mask)
        #   (B, C, 1) -> (B, C, 1, 1)
        context = tf.expand_dims(context, axis=-1)
        #   (B, C, 1, 1) -> (B, 1, 1, C)
        context = tf.transpose(context, perm=[0, 2, 3, 1])
        #   (B, 1, 1, C) -> (B, 1, 1, OC): restore the input channel count
        context = self.conv_up_left(context)
        #   (B, 1, 1, OC) -> (B, 1, 1, OC)
        mask_ch = self.sigmoid_left(context)

        out = inputs * mask_ch
        if self.data_format == 'channels_first':
            #   (B, H, W, C) -> (B, C, H, W)
            out = tf.transpose(out, [0, 3, 1, 2])

        return out

    def spatial_pool(self, inputs):
        """Spatial-attention branch: scale each spatial position of ``inputs``.

        :param inputs: 4-D feature map laid out according to ``self.data_format``
        :return: ``inputs`` multiplied by a per-position mask, in the same
            layout as ``inputs``
        """
        if self.data_format == 'channels_first':
            inputs = tf.transpose(inputs, perm=[0, 2, 3, 1])  # (B, C, H, W) -> (B, H, W, C)
        #   (B, H, W, C) -> (B, H, W, C/2)
        g_x = self.conv_q_right(inputs)
        #   (B, H, W, C/2) -> (B, 1, 1, C/2)
        avg_x = self.global_pool(g_x)

        #   (B, H, W, C) -> (B, H, W, C/2)
        g_v = self.conv_v_right(inputs)
        # Indexed (not unpacked) so the code also works in graph mode.
        dyn_shape = tf.shape(g_v)
        batch, height, width = dyn_shape[0], dyn_shape[1], dyn_shape[2]

        #   (B, 1, 1, C/2) -> (B, 1, C/2)
        avg_x = tf.reshape(tensor=avg_x, shape=(batch, 1, self.out_planes))
        #   (B, H, W, C/2) -> (B, H*W, C/2)
        theta_x = tf.reshape(tensor=g_v, shape=(batch, height * width, self.out_planes))
        #   (B, 1, C/2) matmul (B, C/2, H*W) -> (B, 1, H*W)
        context = tf.matmul(avg_x, tf.transpose(a=theta_x, perm=[0, 2, 1]))
        #   (B, 1, H*W) -> (B, 1, H*W)
        # Although the figure in the paper applies softmax to avg_x before the
        # product, the reference code multiplies first and then applies softmax.
        context = self.softmax_right(context)
        #   (B, 1, H*W) -> (B, 1, H, W)
        context = tf.reshape(context, shape=(batch, 1, height, width))
        #   (B, 1, H, W) -> (B, 1, H, W)
        context = self.sigmoid_right(context)
        #   (B, 1, H, W) -> (B, H, W, 1)
        context = tf.transpose(context, [0, 2, 3, 1])

        out = inputs * context
        if self.data_format == 'channels_first':
            #   (B, H, W, C) -> (B, C, H, W)
            out = tf.transpose(out, [0, 3, 1, 2])
        return out

    def call(self, inputs):
        """Run both branches on ``inputs`` and return their sum."""
        # Spatial attention
        context_spatial = self.spatial_pool(inputs)
        # Channel attention
        context_channel = self.channel_pool(inputs)

        out = context_spatial + context_channel
        return out


class PSA_s(tf.keras.Model):
    """Polarized self-attention block, sequential layout.

    The spatial-attention branch is applied to the input first, and the
    channel-attention branch is then applied to its output.
    """

    def __init__(self, planes, data_format='channels_last'):
        """
        :param planes: number of channels of the input feature map
        :param data_format: layout of the input, 'channels_last' (default) or
            'channels_first'
        """
        super(PSA_s, self).__init__()
        self.data_format = data_format
        self.planes = planes
        self.out_planes = planes // 2
        ratio = 4  # channel reduction factor inside conv_up_left

        # Channel-attention branch ("left").
        self.conv_q_left = tf.keras.layers.Conv2D(filters=1, kernel_size=1, strides=1, padding='valid', use_bias=False)
        self.conv_v_left = tf.keras.layers.Conv2D(filters=self.out_planes, kernel_size=1, strides=1, padding='valid',
                                                  use_bias=False)

        # Bottleneck that maps the pooled C/2 channels back to ``planes``.
        self.conv_up_left = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters=self.out_planes//ratio, kernel_size=1, strides=1, padding='valid'),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Activation(tf.keras.activations.relu),
            tf.keras.layers.Conv2D(filters=self.planes, kernel_size=1, strides=1, padding='valid')
        ])

        # Fix: this softmax is applied to a (B, H*W, 1) tensor and has to
        # normalize over the H*W positions (axis 1). The default axis (-1) has
        # size 1 there, so every weight came out as exactly 1.0.
        self.softmax_left = tf.keras.layers.Softmax(axis=1)
        self.sigmoid_left = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        # Spatial-attention branch ("right").
        self.conv_q_right = tf.keras.layers.Conv2D(filters=self.out_planes, kernel_size=1, strides=1, padding='valid',
                                                   use_bias=False)
        self.conv_v_right = tf.keras.layers.Conv2D(filters=self.out_planes, kernel_size=1, strides=1, padding='valid',
                                                   use_bias=False)
        self.global_pool = tf.keras.layers.GlobalAvgPool2D(keepdims=True)
        # Applied to a (B, 1, H*W) tensor, so the default last axis is the
        # H*W axis here.
        self.softmax_right = tf.keras.layers.Activation(activation=tf.keras.activations.softmax)
        self.sigmoid_right = tf.keras.layers.Activation(activation=tf.keras.activations.sigmoid)

        # Fix: reset_parameters() was defined but never invoked in this class
        # (PSA_p calls it at the end of its constructor).
        self.reset_parameters()

    def reset_parameters(self):
        """Set He initialization on the four query/value convolutions."""
        kaiming_init(self.conv_v_right)
        kaiming_init(self.conv_q_right)
        kaiming_init(self.conv_v_left)
        kaiming_init(self.conv_q_left)

    def channel_pool(self, inputs):
        """Channel-attention branch: scale each channel of ``inputs``.

        :param inputs: 4-D feature map laid out according to ``self.data_format``
        :return: ``inputs`` multiplied by a per-channel sigmoid mask, in the
            same layout as ``inputs``
        """
        if self.data_format == 'channels_first':
            inputs = tf.transpose(inputs, perm=[0, 2, 3, 1])  # (B, C, H, W) -> (B, H, W, C)

        #   (B, H, W, IC) -> (B, H, W, IC/2)
        input_x = self.conv_v_left(inputs)
        # Index the dynamic shape instead of unpacking it: iterating over a
        # symbolic tensor is not allowed in graph mode.
        dyn_shape = tf.shape(input_x)
        batch = dyn_shape[0]
        positions = dyn_shape[1] * dyn_shape[2]
        #   (B, H, W, C) -> (B, H*W, C)
        # The channel count is passed as a Python int so it stays statically
        # known for the layers inside conv_up_left.
        input_x = tf.reshape(input_x, shape=(batch, positions, self.out_planes))

        #   (B, H, W, IC) -> (B, H, W, 1)
        context_mask = self.conv_q_left(inputs)
        #   (B, H, W, 1) -> (B, H*W, 1)
        context_mask = tf.reshape(context_mask, shape=(batch, positions, 1))
        #   (B, H*W, 1) -> (B, H*W, 1), normalized over the H*W axis
        context_mask = self.softmax_left(context_mask)
        #   (B, C, H*W) matmul (B, H*W, 1) -> (B, C, 1)
        context = tf.matmul(a=tf.transpose(input_x, perm=[0, 2, 1]), b=context_mask)
        #   (B, C, 1) -> (B, C, 1, 1)
        context = tf.expand_dims(context, axis=-1)
        #   (B, C, 1, 1) -> (B, 1, 1, C)
        context = tf.transpose(context, perm=[0, 2, 3, 1])

        #   (B, 1, 1, C) -> (B, 1, 1, OC): restore the input channel count
        context = self.conv_up_left(context)
        #   (B, 1, 1, OC) -> (B, 1, 1, OC)
        mask_ch = self.sigmoid_left(context)

        out = inputs * mask_ch
        if self.data_format == 'channels_first':
            #   (B, H, W, C) -> (B, C, H, W)
            out = tf.transpose(out, [0, 3, 1, 2])

        return out

    def spatial_pool(self, inputs):
        """Spatial-attention branch: scale each spatial position of ``inputs``.

        :param inputs: 4-D feature map laid out according to ``self.data_format``
        :return: ``inputs`` multiplied by a per-position mask, in the same
            layout as ``inputs``
        """
        if self.data_format == 'channels_first':
            inputs = tf.transpose(inputs, perm=[0, 2, 3, 1])  # (B, C, H, W) -> (B, H, W, C)
        #   (B, H, W, C) -> (B, H, W, C/2)
        g_x = self.conv_q_right(inputs)
        #   (B, H, W, C/2) -> (B, 1, 1, C/2)
        avg_x = self.global_pool(g_x)

        #   (B, H, W, C) -> (B, H, W, C/2)
        g_v = self.conv_v_right(inputs)
        # Indexed (not unpacked) so the code also works in graph mode.
        dyn_shape = tf.shape(g_v)
        batch, height, width = dyn_shape[0], dyn_shape[1], dyn_shape[2]

        #   (B, 1, 1, C/2) -> (B, 1, C/2)
        avg_x = tf.reshape(tensor=avg_x, shape=(batch, 1, self.out_planes))
        #   (B, H, W, C/2) -> (B, H*W, C/2)
        theta_x = tf.reshape(tensor=g_v, shape=(batch, height * width, self.out_planes))
        #   (B, 1, C/2) matmul (B, C/2, H*W) -> (B, 1, H*W)
        context = tf.matmul(avg_x, tf.transpose(a=theta_x, perm=[0, 2, 1]))
        #   (B, 1, H*W) -> (B, 1, H*W)
        # Although the figure in the paper applies softmax to avg_x before the
        # product, the reference code multiplies first and then applies softmax.
        context = self.softmax_right(context)
        #   (B, 1, H*W) -> (B, 1, H, W)
        context = tf.reshape(context, shape=(batch, 1, height, width))
        #   (B, 1, H, W) -> (B, 1, H, W)
        context = self.sigmoid_right(context)
        #   (B, 1, H, W) -> (B, H, W, 1)
        context = tf.transpose(context, [0, 2, 3, 1])

        out = inputs * context
        if self.data_format == 'channels_first':
            #   (B, H, W, C) -> (B, C, H, W)
            out = tf.transpose(out, [0, 3, 1, 2])
        return out

    def call(self, inputs):
        """Apply spatial attention, then channel attention to its result."""
        # Spatial attention
        out = self.spatial_pool(inputs)
        # Channel attention
        out = self.channel_pool(out)

        return out


if __name__ == '__main__':
    # Smoke test: push an all-zero channels_last batch through each variant
    # and print the shape of what comes back.
    dummy_batch = tf.zeros((4, 224, 224, 16))

    parallel_block = PSA_p(planes=16)
    parallel_out = parallel_block(dummy_batch)
    print("PSA_p output shape = {}".format(tf.shape(parallel_out).numpy()))

    sequential_block = PSA_s(planes=16)
    sequential_out = sequential_block(dummy_batch)
    print("PSA_s output shape = {}".format(tf.shape(sequential_out).numpy()))


  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

早起学习晚上搬砖

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值