With this in place, the swin-block part is done. I think the idea itself is fairly simple; it's the implementation code that is genuinely a bit of a pain. The shift is the easy part; what nearly broke me were the mask, the relative position index, and all the reshape/transpose juggling (a small sanity-check sketch for the mask and the relative position index is at the very end of this post).
custom_function.py
from tensorflow.keras.layers import Input
import tensorflow as tf
def window_partition(x,window_size):
_,H,W,C = x.shape.as_list()
# print(H,W,C)
x = tf.reshape(x,shape=[-1,H//window_size,window_size,
W//window_size,window_size,C])
# -> B,nH,nW,w,w,C
x = tf.transpose(x,[0,1,3,2,4,5])
windows = tf.reshape(x,shape=[-1,window_size,window_size,C])
return windows
def window_reverse(windows,window_size,H,W,C):
# print(f'in window_reverse, {windows.shape}')
x = tf.reshape(windows,shape=[-1,H//window_size,W//window_size,
window_size,window_size,C])
x = tf.transpose(x,[0,1,3,2,4,5])
x = tf.reshape(x,shape=[-1,H,W,C])
return x
def drop_path(inputs,drop_prob,is_training):
    # stochastic depth: randomly drop the whole residual branch, per sample
    if (not is_training) or (not drop_prob):
        return inputs
    keep_prob = 1.0 - drop_prob
    # one random number per sample, broadcast over all remaining dims
    shape = (tf.shape(inputs)[0],) + (1,)*(len(inputs.shape)-1)
    random_tensor = keep_prob + tf.random.uniform(shape,dtype=inputs.dtype)
    binary_tensor = tf.floor(random_tensor)
    # rescale by 1/keep_prob so the expected value stays unchanged
    output = tf.math.divide(inputs,keep_prob) * binary_tensor
    return output
# if __name__ == '__main__':
# inputs = Input(shape=[56,56,96],batch_size=2)
# windows = window_partition(inputs,7)
# print(windows.shape)
# x = window_reverse(windows,7,56,56,96)
# print(x.shape)
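Before moving on to the layers, here is a quick eager-mode sanity check (my own snippet, not part of the file above): window_reverse should exactly undo window_partition, and drop_path should be a no-op when is_training is False.
import numpy as np
import tensorflow as tf
from custom_function import window_partition, window_reverse, drop_path

x = tf.random.normal([2,56,56,96])
windows = window_partition(x,7)           # -> (2*8*8, 7, 7, 96)
y = window_reverse(windows,7,56,56,96)    # -> back to (2, 56, 56, 96)
print(windows.shape, y.shape)
print(np.allclose(x.numpy(), y.numpy()))  # expect True: exact round trip
z = drop_path(x,0.1,False)                # not training -> identity
print(np.allclose(x.numpy(), z.numpy()))  # expect True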
custom_layer.py
from tensorflow.keras.layers import (Layer,Input,LayerNormalization,
Dense,Dropout,Conv2D,)
from tensorflow.keras.activations import gelu
import tensorflow as tf
import numpy as np
from custom_function import (drop_path)
from custom_function import window_partition, window_reverse
class MLPLayer(Layer):
def __init__(self,hidden_features=None,drop_rate=0.,**kwargs):
super(MLPLayer,self).__init__(**kwargs)
self.hidden_features = hidden_features
self.drop_rate = drop_rate
self.fc1 = Dense(self.hidden_features)
self.drop = Dropout(self.drop_rate)
def get_config(self):
config = super(MLPLayer,self).get_config()
config.update({"hidden_features":self.hidden_features,
"out_features":self.out_features,
"drop_rate":self.drop_rate})
return config
def build(self, input_shape):
self.out_features = input_shape[-1]
self.fc2 = Dense(self.out_features)
def call(self,inputs):
x = self.fc1(inputs)
x = gelu(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class WindowAttentionLayer(Layer):
def __init__(self,dim,window_size,num_heads,qkv_bias=True,
qk_scale=None,attn_drop_rate=0.,
proj_drop_rate=0.,**kwargs):
super(WindowAttentionLayer,self).__init__(**kwargs)
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
self.head_dim = dim//num_heads
        self.qk_scale = qk_scale
        self.scale = qk_scale or (self.head_dim ** (-0.5))
self.qkv_bias = qkv_bias
self.attn_drop_rate = attn_drop_rate
self.proj_drop_rate = proj_drop_rate
self.qkv = Dense(self.dim*3,use_bias=self.qkv_bias)
self.attn_drop = Dropout(self.attn_drop_rate)
self.proj = Dense(self.dim)
self.proj_drop = Dropout(self.proj_drop_rate)
def get_config(self):
config = super(WindowAttentionLayer,self).get_config()
config.update({"self.dim":self.dim,
"window_size":self.window_size,
"num_heads":self.num_heads,
"head_dim":self.head_dim,
"scale":self.scale,
"qkv_bias":self.qkv_bias,
"attn_drop_rate":self.attn_drop_rate,
"proj_drop_rate":self.proj_drop_rate})
return config
def build(self, input_shape):
self.relative_position_bias_table = self.add_weight(
shape=[(2*self.window_size[0]-1)*(2*self.window_size[1]-1),
self.num_heads],
initializer=tf.initializers.Zeros(),
trainable=True
)
coords_h = np.arange(self.window_size[0]) # 0-6
coords_w = np.arange(self.window_size[1])
coords = np.stack(np.meshgrid(coords_h,coords_w,indexing='ij'))
coords_flatten = coords.reshape(2,-1)
relative_coords = coords_flatten[:,:,None] - coords_flatten[:,None,:]
relative_coords = relative_coords.transpose([1,2,0])
relative_coords[:,:,0] +=self.window_size[0] - 1
relative_coords[:,:,1] +=self.window_size[1] - 1
relative_coords[:,:,0] *= 2*self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).astype(np.int64)
self.relative_position_index = tf.Variable(
initial_value=tf.convert_to_tensor(relative_position_index),
trainable=False
)
self.built = True
def call(self,x,mask=None):
_,N,C = x.shape.as_list()
qkv = self.qkv(x)
        q,k,v = tf.split(qkv,3,axis=-1) # each: (-1, N, C)
        # -> (-1, num_heads, N, head_dim)
        q = tf.transpose(tf.reshape(q,shape=[-1,N,self.num_heads,self.head_dim]),[0,2,1,3])
        k = tf.transpose(tf.reshape(k,shape=[-1,N,self.num_heads,self.head_dim]),[0,2,1,3])
        v = tf.transpose(tf.reshape(v,shape=[-1,N,self.num_heads,self.head_dim]),[0,2,1,3])
        q = self.scale * q
        # -> (-1, num_heads, N, N)
        attn = tf.matmul(q,k,transpose_b=True)
        # print(f'attn shape after q @ k^T: {attn.shape}')
relative_position_bias = tf.gather(
self.relative_position_bias_table,
tf.reshape(self.relative_position_index,shape=[-1])
)
relative_position_bias = tf.reshape(relative_position_bias,
shape=[self.window_size[0]*self.window_size[1],
self.window_size[0]*self.window_size[1],
-1])
relative_position_bias = tf.transpose(relative_position_bias,
[2,0,1])
        # print(f'relative_position_bias shape: {relative_position_bias.shape}')
attn = attn + tf.expand_dims(relative_position_bias,axis=0)
# print(f'in winattn: {mask.shape}')
        if mask is not None:
mask = tf.convert_to_tensor(mask)
nW = mask.shape[0]
attn = tf.reshape(attn,shape=[-1,nW,self.num_heads,N,N]) + \
tf.cast(tf.expand_dims(tf.expand_dims(mask,axis=1),axis=0),
attn.dtype)
attn = tf.reshape(attn,shape=[-1,self.num_heads,N,N])
attn = tf.nn.softmax(attn,axis=-1)
else:
attn = tf.nn.softmax(attn,axis=-1)
attn = self.attn_drop(attn)
        # -> (-1, N, num_heads, head_dim)
        x = tf.transpose((attn@v),[0,2,1,3])
        # -> (-1, N, C)
        x = tf.reshape(x,shape=[-1,N,C])
x = self.proj(x)
x = self.proj_drop(x)
return x
class DropPathLayer(Layer):
def __init__(self,drop_prob=None,**kwargs):
super(DropPathLayer,self).__init__(**kwargs)
self.drop_prob = drop_prob
def call(self,x,training=None):
return drop_path(x,self.drop_prob,training)
def get_config(self):
config = super(DropPathLayer,self).get_config()
config.update({"drop_prob":self.drop_prob})
return config
class SwinTransformerBlockLayer(Layer):
def __init__(self,dim,input_resolution,num_heads,window_size=7,
shift_size=0,mlp_ratio=4.,qkv_bias=True,qk_scale=None,
drop_rate=0.,attn_drop_rate=0.,drop_path_prob=0.,
**kwargs):
super(SwinTransformerBlockLayer,self).__init__(**kwargs)
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_prob = drop_path_prob
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0<=self.shift_size<self.window_size,'shift_size must be in [0, window_size)'
self.norm1 = LayerNormalization(epsilon=1e-5)
self.attn = WindowAttentionLayer(self.dim,(self.window_size,self.window_size),
self.num_heads,self.qkv_bias,self.qk_scale,
self.attn_drop_rate,self.drop_rate)
self.drop_path = DropPathLayer(self.drop_path_prob)
self.norm2 = LayerNormalization(epsilon=1e-5)
mlp_hidden_dim = int(dim*self.mlp_ratio)
self.mlp = MLPLayer(hidden_features=mlp_hidden_dim,
drop_rate=self.drop_rate)
def build(self,input_shape):
if self.shift_size > 0:
H,W = self.input_resolution
img_mask = np.zeros([1,H,W,1])
            h_slices = (slice(0,-self.window_size),
                        slice(-self.window_size,-self.shift_size),
                        slice(-self.shift_size,None))
            w_slices = (slice(0,-self.window_size),
                        slice(-self.window_size,-self.shift_size),
                        slice(-self.shift_size,None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:,h,w,:] = cnt
                    cnt += 1
img_mask = tf.convert_to_tensor(img_mask)
mask_windows = window_partition(img_mask,self.window_size)
# print(f'in if {mask_windows.shape}')
mask_windows = tf.reshape(mask_windows,shape=[
-1,self.window_size*self.window_size
])
# -1,1,49 - -1,49,1 => -1,49,49
# print(f'in if {mask_windows.shape}')
attn_mask = tf.expand_dims(mask_windows,axis=1) - tf.expand_dims(mask_windows,axis=2)
attn_mask = tf.where(attn_mask!=0, -100.0, attn_mask)
attn_mask = tf.where(attn_mask==0, 0.0, attn_mask)
self.attn_mask = tf.Variable(initial_value=attn_mask,
trainable=False)
# print('in if',self.attn_mask.shape)
else:
self.attn_mask = None
# print('in else')
self.built = True
# print(f'in build, attn_mask={self.attn_mask}')
def get_config(self):
config = super(SwinTransformerBlockLayer,self).get_config()
config.update({"dim":self.dim,
"input_resolution":self.input_resolution,
"num_heads":self.num_heads,
"window_size":self.window_size,
"shift_size":self.shift_size,
"mlp_ratio":self.mlp_ratio,
"qkv_bias":self.qkv_bias,
"qk_scale":self.qk_scale,
"drop_rate":self.drop_rate,
"attn_drop_rate":self.attn_drop_rate,
"drop_path_prob":self.drop_path_prob,
})
return config
def call(self,x):
# print(f'in call: {self.attn_mask}')
H,W = self.input_resolution
_,L,C = x.shape.as_list()
assert L == H*W, 'input feature has wrong size.'
shortcut = x
x = self.norm1(x)
x = tf.reshape(x,shape=[-1,H,W,C])
# cyclic shift
if self.shift_size > 0:
shifted_x = tf.roll(x,shift=[-self.shift_size,-self.shift_size],
axis=[1,2])
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x,self.window_size)
x_windows = tf.reshape(x_windows,
shape=[-1,self.window_size*self.window_size,C])
# w-msa/sw-msa
        # print('attn_mask shape before attention:',self.attn_mask.shape)
attn_windows = self.attn(x_windows,mask=self.attn_mask)
# print(f'做完msa之后的shape: {attn_windows.shape}')
# merge windows
attn_windows = tf.reshape(attn_windows,
shape=[-1,self.window_size,self.window_size,C])
shifted_x = window_reverse(attn_windows,self.window_size,H,W,C)
# reverse cyclic shift
if self.shift_size > 0:
x = tf.roll(shifted_x,
shift=[self.shift_size,self.shift_size],
axis=[1,2])
else:
x = shifted_x
x = tf.reshape(x,shape=[-1,H*W,C])
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMergingLayer(Layer):
def __init__(self,input_resolution,dim,**kwargs):
super(PatchMergingLayer,self).__init__(**kwargs)
self.input_resolution = input_resolution
self.dim = dim
self.norm = LayerNormalization(epsilon=1e-5)
self.reduction = Dense(2*self.dim,use_bias=False)
def get_config(self):
config = super(PatchMergingLayer,self).get_config()
config.update({"input_resolution":self.input_resolution,
"dim":self.dim})
return config
def call(self,x):
H,W = self.input_resolution
B,L,C = x.shape.as_list()
assert L==H*W, 'input feature has wrong size'
assert H%2==0 and W%2==0, f'x size ({H}*{W}) are not even.'
x = tf.reshape(x,shape=[-1,H,W,C])
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = tf.concat([x0, x1, x2, x3], axis=-1)
x = tf.reshape(x, shape=[-1, (H // 2) * (W // 2), 4 * C])
x = self.norm(x)
x = self.reduction(x)
return x
class PatchEmbeddingLayer(Layer):
def __init__(self,img_size=[224,224],patch_size=[4,4],
embed_dims=96,**kwargs):
super(PatchEmbeddingLayer,self).__init__(**kwargs)
self.img_size = img_size
self.patch_size = patch_size
self.embed_dims = embed_dims
patchs_resolution = [self.img_size[0]//self.patch_size[0],
self.img_size[1]//self.patch_size[1]]
self.patchs_resolution = patchs_resolution
self.num_patches = patchs_resolution[0] * patchs_resolution[1]
self.proj = Conv2D(self.embed_dims,self.patch_size,
self.patch_size)
def get_config(self):
config = super(PatchEmbeddingLayer,self).get_config()
config.update({"img_size":self.img_size,
"patch_size":self.patch_size,
"embed_dims":self.embed_dims,
"patchs_resolution":self.patchs_resolution,
"num_patches":self.num_patches})
return config
def call(self,x):
_,H,W,C = x.shape.as_list()
assert H==self.img_size[0] and W==self.img_size[1], \
f'input img size ({H}*{W}) does not match model ({self.img_size[0]}*{self.img_size[1]}).'
x = self.proj(x)
_,h,w,c = x.shape.as_list()
x = tf.reshape(x,shape=[-1,h*w,c])
return x
if __name__ == '__main__':
inputs = Input(shape=[224,224,3])
    # patch embedding
    x = PatchEmbeddingLayer()(inputs)
    print(f'output shape after patch_embedding (b,56*56,96): {x.shape}')
    # a pair of swin transformer blocks
    # shift_size=0; num_heads=3; window_size=7; mlp_ratio=4
    x = SwinTransformerBlockLayer(96,[224//4,224//4],3,7,0,4)(x)
    print(f'output shape after one STB without shift (b,56*56,96): {x.shape}')
    # shift_size=3; num_heads=3; window_size=7; mlp_ratio=4
    x = SwinTransformerBlockLayer(96,[224//4,224//4],3,7,3,4)(x)
    print(f'output shape after one STB with shift (b,56*56,96): {x.shape}')
    # patch_merging: H and W are halved, channels are doubled
    x = PatchMergingLayer([224//4,224//4],96)(x)
    print(f'output shape after patch_merging (b,28*28,96*2): {x.shape}')
Running the script prints:
output shape after patch_embedding (b,56*56,96): (None, 3136, 96)
output shape after one STB without shift (b,56*56,96): (None, 3136, 96)
output shape after one STB with shift (b,56*56,96): (None, 3136, 96)
output shape after patch_merging (b,28*28,96*2): (None, 784, 192)
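Finally, since the mask and the relative position index were the parts that tripped me up, here is a small worked example (my own sketch; it just mirrors the logic in WindowAttentionLayer.build and SwinTransformerBlockLayer.build, shrunk to a toy 2x2 window and a 4x4 resolution with shift_size=1 so the numbers are easy to eyeball):
import numpy as np

# relative position index for a 2x2 window (the real layers use 7x7)
wh, ww = 2, 2
coords = np.stack(np.meshgrid(np.arange(wh),np.arange(ww),indexing='ij'))
coords_flatten = coords.reshape(2,-1)                       # (2, 4)
rel = coords_flatten[:,:,None] - coords_flatten[:,None,:]   # (2, 4, 4)
rel = rel.transpose([1,2,0])                                # (4, 4, 2)
rel[:,:,0] += wh - 1
rel[:,:,1] += ww - 1
rel[:,:,0] *= 2*ww - 1
print(rel.sum(-1))
# [[4 3 1 0]
#  [5 4 2 1]
#  [7 6 4 3]
#  [8 7 5 4]]
# every entry indexes one of the (2*2-1)*(2*2-1)=9 rows of the bias table

# region map behind the SW-MSA mask: 4x4 resolution, window_size=2, shift_size=1
H, W, window_size, shift_size = 4, 4, 2, 1
img_mask = np.zeros([1,H,W,1])
h_slices = (slice(0,-window_size),slice(-window_size,-shift_size),slice(-shift_size,None))
w_slices = (slice(0,-window_size),slice(-window_size,-shift_size),slice(-shift_size,None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:,h,w,:] = cnt
        cnt += 1
print(img_mask[0,:,:,0])
# [[0. 0. 1. 2.]
#  [0. 0. 1. 2.]
#  [3. 3. 4. 5.]
#  [6. 6. 7. 8.]]
The key points: equal relative offsets share the same bias-table row, and after window_partition any pair of positions that sits in the same window but in different regions gets -100 in attn_mask, so the softmax effectively zeroes out attention across the rolled-over boundary.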