前言
Feature Pyramid Transformer出自于ECCV2020,本篇文章主要是把Transformer的主要架构self-attention进行了三种形式的演变( Self-Transformer,Grounding Transformer,Rendering Transformer),加入到了FPN的结构当中,起名叫FPT。
其主要架构如上图,本文中最主要的是它所提出的三个结构,Self-Transformer,Grounding Transformer,Rendering Transformer。以下简称ST,GT,RT。以上图b为例,其中ST,就是自己和自己进行self-attention,RT就是高位特征a,b之间的操作,GT,就是低维特征c,b之间的操作,通过这张图我们知道除了ST,a只有和b之间的GT。而c只有和b之间的RT。
ST
ST的公式如下所示
与self-attention的公式类似多了一个N,N主要应用在Weight的Fmos的这个公式中,文中引入了
改进了softmax,其他的和self-attention没有什么区别,具体实现方法:
def ST(x,H,W,C,N):
x=tf.layers.conv2d(x,C*3,1,strides=1, padding='same')
Q,K,V=tf.split(x, 3, axis=3)
Q=tf.reshape(Q,[-1,H*W,C])
K=tf.reshape(K,[-1,H*W,C])
V=tf.reshape(V,[-1,H*W,C])
K_trans=tf.transpose(K,[0,2,1])
result_all=tf.Variable(tf.zeros([1,H*W,H*W]))
for a in range(N):
WT=tf.Variable(tf.random_normal([H*W]))
k_mean,K_var=tf.nn.moments(K[:,:,int(C/N)*(a):int(C/N)*(a+1)],2)
result=tf.matmul(Q[:,:,int(C/N)*(a):int(C/N)*(a+1)],K_trans[:,int(C/N)*(a):int(C/N)*(a+1),:])
result_all=result_all+tf.nn.softmax(k_mean*WT)*tf.nn.softmax(result)
V=tf.matmul(result_all,V)
V=tf.reshape(V,[-1,H,W,C])
print(V)
N为个数,我们这里采用了for循环叠加的方案,完成的复现。
GT
GT主要采用的是欧式距离的方法,计算高维特征和低维特征,具体代码如下:
def GT(x,y,BZ,H,W,C,Expand,N):
x=tf.layers.conv2d_transpose(x,C*2,1,strides=Expand, padding='same')
K=tf.layers.conv2d(y,C,1,strides=1, padding='same')
Q,V=tf.split(x, 2, axis=3)
Q=tf.reshape(Q,[BZ,H*W,C])
K=tf.reshape(K,[BZ,H*W,C])
V=tf.reshape(V,[BZ,H*W,C])
print(Q,K,V)
result_all=tf.zeros([BZ,H*W,1])
for a in range(N):
WT=tf.Variable(tf.random_normal([H*W]))
k_mean,K_var=tf.nn.moments(K[:,:,int(C/N)*(a):int(C/N)*(a+1)],2)
result=-tf.square(Q[:,:,int(C/N)*(a):int(C/N)*(a+1)]-K[:,:,int(C/N)*(a):int(C/N)*(a+1)])
result=tf.nn.softmax(tf.reshape(k_mean*WT,[-1,H*W,1]))*tf.nn.softmax(result)
print(result)
result_all=tf.concat([result_all,result],-1)
print(result_all)
V=result_all[:,:,1:]*V
V=tf.reshape(V,[-1,H,W,C])
print(V)
其中x代表高维特征,y代表低维特征。
RT
RT公式如下,Q是图一中b的,K,V是c的
def RT(x,y,H,W,H1,W1,C,narrow): #x:high-level feature map y:low-level feature map
y=tf.layers.conv2d(y,C*2,1,strides=1, padding='same')
K,V =tf.split(y, 2, axis=3)
print(V)
Q=tf.layers.conv2d(x,C,1,strides=1, padding='same')
Q=tf.reshape(Q,[-1,H*W,C])
K=tf.reshape(K,[-1,H1*W1,C])
#
w=tf.keras.layers.GlobalAvgPool1D()(K)
Qatt=Q*tf.reshape(w,[-1,1,C])
Vdow=tf.layers.conv2d(V,C,3,strides=narrow, padding='same')
Qatt=tf.reshape(Qatt,[-1,H,W,C])
Xc=tf.layers.conv2d(Qatt,C,3,strides=1, padding='same')+Vdow
print(Xc)