在网上找了很久,没有找到 CoT layer(CoTNet)的 Keras 实现,
所以自己实现了一个。
论文和源代码大家百度一下就可以找到,别的博客里都有。
使用时把 ResNet 中的 3x3 卷积换成 CoTNetLayer 即可。
class key_embed(tf.keras.layers.Layer):
    """Static-context (key) embedding: K x K convolution -> batch norm -> ReLU.

    Args:
        dim: Number of output channels (filters) of the convolution.
        kernel_size: Side length of the square convolution kernel. Defaults
            to 3, the value that used to be hard-coded.
    """

    def __init__(self, dim, kernel_size=3):  # e.g. dim=512, kernel_size=3
        super().__init__()
        # A K x K convolution gathers neighbourhood context; its output is
        # treated as the static contextual representation of the input.
        # padding='same' (with the default stride of 1) keeps the spatial
        # size of the feature map unchanged.
        self.convk1 = tf.keras.layers.Conv2D(
            filters=dim,
            kernel_size=(kernel_size, kernel_size),
            padding='same',
        )
        self.bnk1 = tf.keras.layers.BatchNormalization()
        # NOTE: the stray, unassigned ``tf.keras.layers.ReLU()`` expression
        # was removed; the activation is applied functionally in ``call``.

    def call(self, inputs, **kwargs):
        """Apply convolution, batch normalisation and ReLU to ``inputs``."""
        x = self.convk1(inputs)
        x = self.bnk1(x)
        return tf.nn.relu(x)
class value_embed(tf.keras.layers.Layer):
    """Value embedding: a pointwise (1 x 1) convolution followed by batch norm.

    Args:
        dim: Number of output channels (filters) of the 1 x 1 convolution.
    """

    def __init__(self, dim):  # e.g. dim=512
        super().__init__()
        # The 1 x 1 convolution encodes the value at each position.
        self.convv1 = tf.keras.layers.Conv2D(filters=dim, kernel_size=(1, 1))
        self.bnv1 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, **kwargs):
        """Return the batch-normalised 1 x 1 convolution of ``inputs``."""
        return self.bnv1(self.convv1(inputs))
class attention_embed(tf.keras.layers.Layer):
    """Attention embedding built from two consecutive 1 x 1 convolutions.

    The first convolution projects the input to ``2 * dim // factor``
    channels (followed by batch norm and ReLU); the second expands it to
    ``kernel_size * kernel_size * dim`` channels.

    Args:
        dim: Base channel count C.
        kernel_size: K; the output has ``K * K * dim`` channels.
        factor: Reduction factor of the bottleneck convolution. Defaults to
            4, the value that used to be hard-coded.
    """

    def __init__(self, dim, kernel_size=3, factor=4):
        super().__init__()
        # Bottleneck over the input; the caller is expected to feed the
        # channel-wise concatenation of key and query features.
        self.conva1 = tf.keras.layers.Conv2D(2 * dim // factor, kernel_size=(1, 1))
        self.bna1 = tf.keras.layers.BatchNormalization()
        # NOTE: the stray, unassigned ``tf.keras.layers.ReLU()`` expression
        # was removed; the activation is applied functionally in ``call``.
        # Output: H x W x (K * K * C).
        self.conva2 = tf.keras.layers.Conv2D(
            filters=kernel_size * kernel_size * dim,
            kernel_size=(1, 1),
        )

    def call(self, inputs, **kwargs):
        """Return the un-normalised attention tensor for ``inputs``."""
        x = self.conva1(inputs)
        x = self.bna1(x)
        x = tf.nn.relu(x)
        return self.conva2(x)
class CoTNetLayer(tf.keras.layers.Layer):
    """Contextual Transformer (CoT) layer for channels-last feature maps.

    The layer fuses a static context (key embedding of the input) with a
    dynamic context (softmax attention over all spatial positions applied to
    the value embedding) and returns their sum.

    Args:
        dim: Number of channels produced by the key/value embeddings and
            therefore by this layer.
        kernel_size: K; the attention embedding emits ``K * K`` channels per
            feature channel, which are averaged in ``call``. The key
            embedding is constructed with its own default kernel.
        **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments
            (``name``, ``dtype``, ...).
    """

    def __init__(self, dim=16, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.kernel_size = kernel_size
        self.key_embed = key_embed(dim=self.dim)
        self.value_embed = value_embed(dim=self.dim)
        # Pass kernel_size through so the number of attention channels always
        # matches the (c, k*k) split performed in ``call``.
        self.attention_embed = attention_embed(
            dim=self.dim, kernel_size=self.kernel_size
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"dim": self.dim, "kernel_size": self.kernel_size})
        return config

    def call(self, x):
        """Apply the CoT block.

        Args:
            x: Feature map laid out as (batch, height, width, channels), the
                default ``channels_last`` layout of the Conv2D sub-layers.
                The height must be statically known.

        Returns:
            Tensor of shape (batch, height, width, dim).
        """
        # Channels-last layout: axis 1 is the height (the original unpacked
        # the shape in channels-first order and picked up the width instead).
        _, h, _, _ = x.shape

        # Static context: key embedding, shape (bs, h, w, dim).
        k1 = self.key_embed(x)

        # Value embedding flattened over space, shape (bs, h*w, dim).
        # ``rearrange`` is expected to be in scope at file level
        # (presumably ``einops.rearrange``).
        v = rearrange(self.value_embed(x), 'bs h w c -> bs (h w) c')

        # Concatenate key and query along the channel axis.
        y = tf.concat([k1, x], axis=-1)

        # Attention tensor, shape (bs, h, w, dim * k * k).
        att = self.attention_embed(y)
        att = rearrange(
            att,
            'bs h w (c c1) -> bs h w c c1',
            c1=self.kernel_size * self.kernel_size,
        )
        # Average over the k*k axis to get back to (bs, h, w, dim).
        # tf.reduce_mean is used because TensorFlow tensors have no ``.mean``.
        att = tf.reduce_mean(att, axis=-1)
        att = rearrange(att, 'bs h w c -> bs (h w) c')

        # Softmax over the spatial positions, then weight the values.
        k2 = tf.nn.softmax(att, axis=-2) * v
        k2 = rearrange(k2, 'bs (h w) c -> bs h w c', h=h)

        # Fuse static and dynamic context.
        return k1 + k2