使用transformer处理图像数据,需要照特定格式对矩阵分块,并拉伸flatten,在完成最后的卷积后,需要重新将token的channel重新reshape成图像格式。类似下图,将输入首先分块,然后拉伸为NxC的vector,然后重新reshape为图像格式,这里使用一通道简要说明。
小代码
def reshape():
h = 6
a = tf.random_uniform([h,h],maxval=40,dtype=tf.int32)
b = tf.reshape(a,[2,3,2,3])
c = tf.transpose(b,[0,2,1,3])
d = tf.reshape(tf.reshape(c,[-1,3,3]),[4,-1])
return a,d
def rereshape(x):
a = tf.reshape(x,[2,2,3,3])
b = tf.transpose(a,[0,2,1,3])
c = tf.reshape(tf.reshape(b,[6,2,3]),[6,-1])
return c
with tf.Session() as se:
a,b = se.run(reshape())
print('a:',a)
print('b:',b)
c = se.run(rereshape(b))
print('c:',c)
输出