A Keras implementation of CoTNet

I searched online for a long time but could not find a Keras implementation of the CoT layer, so I wrote one myself.
The paper and the original source code are easy to find with a quick search, and other blog posts already cover them, so I won't repeat that here.
To use it, simply replace the 3x3 convolution in ResNet with the CoT layer (a usage sketch follows the code below).

import tensorflow as tf
from einops import rearrange


class key_embed(tf.keras.layers.Layer):
    def __init__(self, dim):  # e.g. dim=512; the kernel size is fixed at 3
        super(key_embed, self).__init__()
        # A K*K convolution extracts local context, treated as the static
        # contextual representation of the input X.
        self.convk1 = tf.keras.layers.Conv2D(filters=dim, kernel_size=(3, 3), padding='same')
        self.bnk1 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, **kwargs):
        x = self.convk1(inputs)
        x = self.bnk1(x)
        output = tf.nn.relu(x)
        return output

class value_embed(tf.keras.layers.Layer):
    def __init__(self, dim):  # e.g. dim=512
        super(value_embed, self).__init__()
        self.convv1 = tf.keras.layers.Conv2D(filters=dim, kernel_size=(1, 1))  # a 1*1 convolution encodes the values
        self.bnv1 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, **kwargs):
        x = self.convv1(inputs)
        output = self.bnv1(x)
        return output

class attention_embed(tf.keras.layers.Layer):  # two consecutive 1*1 convolutions compute the attention matrix
    def __init__(self, dim, kernel_size=3):  # e.g. dim=512, factor=4, kernel_size=3
        super(attention_embed, self).__init__()
        factor = 4
        self.conva1 = tf.keras.layers.Conv2D(2 * dim // factor, kernel_size=(1, 1))  # input is the concatenated feature map with 2*C channels
        self.bna1 = tf.keras.layers.BatchNormalization()
        self.conva2 = tf.keras.layers.Conv2D(filters=kernel_size * kernel_size * dim, kernel_size=(1, 1))  # out: H * W * (K*K*C)

    def call(self, inputs, **kwargs):
        x = self.conva1(inputs)
        x = self.bna1(x)
        x = tf.nn.relu(x)
        output = self.conva2(x)
        return output



class CoTNetLayer(tf.keras.layers.Layer):

    def __init__(self, dim=16, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.key_embed = key_embed(dim=self.dim)
        self.value_embed = value_embed(dim=self.dim)
        self.attention_embed = attention_embed(dim=self.dim, kernel_size=self.kernel_size)

    def call(self, x):
        bs, h, w, c = x.shape  # channels-last layout: batch, height, width, channels
        k1 = self.key_embed(x)  # shape: bs,h,w,c  static context (key) from the K*K convolution
        z = self.value_embed(x)  # shape: bs,h,w,c  value encoding
        v = rearrange(z, 'bs h w c -> bs (h w) c')  # flatten the spatial dimensions
        y = tf.keras.layers.concatenate([k1, x], axis=-1)  # shape: bs,h,w,2c  concatenate key and query along the channel axis
        att = self.attention_embed(y)  # shape: bs,h,w,c*k*k  attention matrix
        att = rearrange(att, 'bs h w (c c1) -> bs h w c c1', c1=self.kernel_size * self.kernel_size)
        att = tf.reduce_mean(att, axis=4)  # average over the k*k dimension to reduce it
        att = rearrange(att, 'bs h w c -> bs (h w) c')  # shape: bs,h*w,c
        k2 = tf.nn.softmax(att, axis=-2) * v  # softmax over the h*w positions, then weight the values
        k2 = rearrange(k2, 'bs (h w) c -> bs h w c', h=h)

        return k1 + k2  # fuse static and dynamic context
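
Below is a quick sanity check plus a sketch of the "swap the 3x3 convolution" idea in a ResNet-style bottleneck block. The block layout, the channel counts, and the helper name bottleneck_block are my own illustrative assumptions, not part of the original CoTNet code; only the idea of dropping CoTNetLayer in place of the 3x3 convolution comes from the text above. It assumes einops is installed (pip install einops).

import tensorflow as tf

# Sanity check: run the layer eagerly on a dummy feature map.
x = tf.random.normal((2, 32, 32, 64))          # batch, height, width, channels
cot = CoTNetLayer(dim=64, kernel_size=3)
print(cot(x).shape)                            # -> (2, 32, 32, 64)

# Sketch of a ResNet bottleneck where the middle 3x3 convolution is
# replaced by the CoT layer (hypothetical helper, channel counts assumed).
def bottleneck_block(inputs, filters):
    shortcut = inputs
    x = tf.keras.layers.Conv2D(filters, 1, use_bias=False)(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    x = CoTNetLayer(dim=filters, kernel_size=3)(x)  # was: Conv2D(filters, 3, padding='same')
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)

    x = tf.keras.layers.Conv2D(4 * filters, 1, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)

    if shortcut.shape[-1] != 4 * filters:  # project the shortcut when channel counts differ
        shortcut = tf.keras.layers.Conv2D(4 * filters, 1, use_bias=False)(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)
    return tf.keras.layers.ReLU()(tf.keras.layers.Add()([x, shortcut]))

inputs = tf.keras.Input(shape=(32, 32, 64))
outputs = bottleneck_block(inputs, filters=64)
model = tf.keras.Model(inputs, outputs)
model.summary()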