A TensorFlow implementation of the Criss-Cross Attention module from the CCNet (Criss-Cross Network) semantic segmentation paper


I am learning as I write this code, so if you find problems or have corrections, please get in touch!

Module structure

(Figure: Criss-Cross Attention module structure)

The Affinity operation

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras


class criss_cross_attention_Affinity(tf.keras.layers.Layer):
    """For every pixel, computes its affinity with the H + W - 1 pixels
    that share its row or column (the criss-cross path)."""

    def __init__(self, axis=1, **kwargs):
        super(criss_cross_attention_Affinity, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        batch_size, H, W, Channel = x.shape
        outputs = []
        for i in range(H):
            for j in range(W):
                # Feature vector of the current pixel: (batch, C)
                ver = x[:, i, j, :]
                # Its criss-cross neighbourhood: the rest of row i plus the whole
                # column j (the pixel itself enters once, via the column),
                # giving (batch, H + W - 1, C)
                temp_x = tf.concat([x[:, i, 0:j, :], x[:, i, j + 1:W, :], x[:, :, j, :]], axis=1)
                # Dot product with every neighbour: (batch, H + W - 1)
                trans_temp = tf.matmul(temp_x, tf.expand_dims(ver, -1))
                trans_temp = tf.squeeze(trans_temp, -1)
                trans_temp = tf.expand_dims(trans_temp, axis=1)
                outputs.append(trans_temp)
        # (batch, H * W, H + W - 1) -> (batch, H, W, H + W - 1)
        outputs = tf.concat(outputs, axis=self.axis)
        C = outputs.shape[2]
        outputs = tf.reshape(outputs, [-1, H, W, C])
        return outputs

    def get_config(self):
        config = {'axis': self.axis}
        base_config = super(criss_cross_attention_Affinity, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
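
A quick sanity check of the output shape (a minimal sketch in eager mode; the 2×8×8×64 input is just an example, not from the original post): each pixel should get H + W - 1 affinity values.

x = tf.random.normal([2, 8, 8, 64])
affinity = criss_cross_attention_Affinity()(x)
print(affinity.shape)  # (2, 8, 8, 15) -- 15 = H + W - 1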

The Aggregation operation

class criss_cross_attention_Aggregation(tf.keras.layers.Layer):
    """Weights each pixel's criss-cross neighbourhood by the softmaxed
    affinity map and sums it back into a single feature vector."""

    def __init__(self, axis=1, **kwargs):
        super(criss_cross_attention_Aggregation, self).__init__(**kwargs)
        self.axis = axis

    def call(self, x, Affinity):
        batch_size, H, W, Channel = x.shape
        # Normalise the H + W - 1 affinities of each pixel
        Affinity = tf.nn.softmax(Affinity, axis=-1)
        outputs = []
        for i in range(H):
            for j in range(W):
                # Attention weights of the current pixel: (batch, H + W - 1)
                ver = Affinity[:, i, j, :]
                # Same criss-cross neighbourhood, in the same order as in the
                # Affinity layer: (batch, H + W - 1, C)
                temp_x = tf.concat([x[:, i, 0:j, :], x[:, i, j + 1:W, :], x[:, :, j, :]], axis=1)
                # Weighted sum over the neighbourhood: (batch, 1, C)
                trans_temp = tf.matmul(tf.transpose(tf.expand_dims(ver, -1), [0, 2, 1]), temp_x)
                outputs.append(trans_temp)
        # (batch, H * W, C) -> (batch, H, W, C)
        outputs = tf.concat(outputs, axis=self.axis)
        C = outputs.shape[2]
        outputs = tf.reshape(outputs, [-1, H, W, C])
        return outputs

    def get_config(self):
        config = {'axis': self.axis}
        base_config = super(criss_cross_attention_Aggregation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
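
Continuing the sanity check from above, aggregation should map the feature map and its affinity tensor back to the original feature shape:

out = criss_cross_attention_Aggregation()(x, affinity)
print(out.shape)  # (2, 8, 8, 64) -- same shape as the input feature map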

Combining the two operations

def criss_cross_attention(x):
    # Downsample first so the per-pixel loops run on a smaller feature map
    x = layers.Conv2D(filters=64, kernel_size=3, padding='same', strides=2)(x)
    x_origin = x
    affinity = criss_cross_attention_Affinity(1)(x)
    out = criss_cross_attention_Aggregation(1)(x, affinity)
    # Residual connection back to the input feature map
    out = layers.Add()([out, x_origin])
    out = layers.UpSampling2D(size=2, interpolation='bilinear')(out)
    return out

Model summary output

(Figure: model.summary() output)
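
A summary like the one in the figure can be reproduced along these lines (a minimal sketch; the 16×16×3 input shape is my assumption, not from the original post):

inputs = keras.Input(shape=(16, 16, 3))
outputs = criss_cross_attention(inputs)
model = keras.Model(inputs, outputs)
model.summary()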

Open problem

Because this module computes attention for every pixel over the criss-cross set of pixels in that pixel's row and column of the feature map, the code loops over each pixel one by one. This makes the computation very expensive, and I have not solved this problem yet.
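
One common way to avoid the per-pixel loop (a sketch of a vectorized alternative, not the code from this post; it assumes static H and W and, like the loop version, uses the raw feature map as query, key and value) is to compute all row and column energies at once with einsum and mask out the duplicated self-position:

def criss_cross_attention_vectorized(x):
    # x: (B, H, W, C) with static H and W
    B, H, W, C = x.shape
    # Energies along each row: pixel (h, w) vs. every (h, v) -> (B, H, W, W)
    energy_w = tf.einsum('bhwc,bhvc->bhwv', x, x)
    # Energies along each column: pixel (h, w) vs. every (u, w) -> (B, H, W, H)
    energy_h = tf.einsum('bhwc,buwc->bhwu', x, x)
    # The pixel itself appears in both branches; mask it out of the column
    # branch so it is only counted once (via the row branch).
    energy_h += tf.reshape(tf.eye(H) * -1e9, [1, H, 1, H])
    # Joint softmax over all H + W candidate positions
    attn = tf.nn.softmax(tf.concat([energy_h, energy_w], axis=-1), axis=-1)
    attn_h, attn_w = attn[..., :H], attn[..., H:]
    # Weighted sums of the column and row neighbourhoods -> (B, H, W, C)
    out_h = tf.einsum('bhwu,buwc->bhwc', attn_h, x)
    out_w = tf.einsum('bhwv,bhvc->bhwc', attn_w, x)
    return out_h + out_w

Note that the masked self-position still receives a vanishing (but nonzero) softmax weight, so this matches the loop version only up to that approximation.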
