I recently studied the Focal Transformer paper and reproduced the code. My understanding is not yet exhaustive, so these are just my own notes. The write-ups online cover the theory and the algorithm very thoroughly, but few of them walk through the code, so I reproduced it myself and am writing it up here. This is also my first real look at Transformers.
Paper: https://arxiv.org/abs/2107.00641
Code: https://github.com/microsoft/Focal-Transformer
Focal Transformer can be seen as a modification of Swin Transformer, and its code overlaps heavily with Swin's. Its most novel part is, of course, the newly proposed focal self-attention.
Each stage is one Transformer module, and the biggest innovation within a stage is the focal self-attention. I see this attention as consisting of two main parts: window_partition plus window-wise attention (the latter is the W-MSA proposed in Swin; SW-MSA is also provided in the code, I just did not use it).
To summarize: the base code of Focal is identical to Swin's. So what differs? First, Focal considers multiple levels. Loosely speaking, Swin operates on a single level, while Focal operates on that level and then additionally on a coarser-grained level. Second, it does not use Swin's newly introduced SW-MSA but sticks with W-MSA (the code does provide SW-MSA as well; I simply did not use it).
Below I analyze the Focal Transformer code (I only include the more important snippets, not everything).
1. patch_embedding
Before entering the Transformer, the first thing to do is patch embedding. Say the input image is (1, 512, 224, 224), i.e. batch size 1, 512 channels, image size 224*224: you set the constructor arguments to match your input, and embed_dim is however many dimensions (channels) you want each patch to have after the projection and flattening.
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        use_conv_embed (bool): Whether to use an overlapped convolutional embedding layer. Default: False.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
        use_pre_norm (bool): Whether to use pre-normalization before projection. Default: False
        is_stem (bool): Whether the current patch embedding is the stem. Default: False
    """
    def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96,
                 use_conv_embed=False, norm_layer=None, use_pre_norm=False, is_stem=False):
        super().__init__()
Example: input (8, 64, 256, 256). With the call below, the output is (8, 64*64, 64). The batch size is unchanged; with patch_size=4, the 256*256 image is divided into 64*64 patches, each of size 4*4 with 64 channels. Flattened directly, each patch would become a vector of length 4*4*64, but we want a vector of length 64, so a linear projection maps each patch to 64 dimensions first, and then everything is flattened into the token sequence. The final output is a batch of 8 with 64*64 tokens, each of length 64.
self.patch4 = PatchEmbed(img_size=(256, 256), patch_size=4, in_chans=64, embed_dim=64,
                         use_conv_embed=False, norm_layer=None, use_pre_norm=False, is_stem=False)
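To make the shape bookkeeping concrete, here is a minimal self-contained sketch of the non-convolutional embedding path: a strided Conv2d acting as the patch-wise linear projection, followed by flattening. `TinyPatchEmbed` is a hypothetical name for illustration, not the repo's class:

```python
import torch
import torch.nn as nn

# Sketch only (assumes use_conv_embed=False): a Conv2d with
# kernel_size = stride = patch_size projects each 4x4x64 patch to a
# 64-d vector, then the spatial grid is flattened into a token sequence.
class TinyPatchEmbed(nn.Module):
    def __init__(self, patch_size=4, in_chans=64, embed_dim=64):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, C, H, W)
        x = self.proj(x)                     # (B, embed_dim, H/4, W/4)
        return x.flatten(2).transpose(1, 2)  # (B, H/4 * W/4, embed_dim)

x = torch.randn(8, 64, 256, 256)
out = TinyPatchEmbed()(x)
print(out.shape)  # torch.Size([8, 4096, 64]), i.e. (8, 64*64, 64)
```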
2.window_partition
How exactly the windows are divided is explained in the paper.
For the three most important parameters, window size, focal level, and focal window, I used the configuration from the paper (the code also ships with the 1/3/5/7 setting, and I recommend sticking to it. With my 512 input I originally wanted window size = 8 with a 1/2/4/8 setting, but that errored out; I have not looked into how to fix it, so for now I ran with 1/3/5/7). The window_partition is mainly carried out inside the FocalTransformerBlock class, so let's look at that code.
class FocalTransformerBlock(nn.Module):
    r""" Focal Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        expand_size (int): expand size at first focal level (finest level).
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
        focal_level (int): number of focal levels. Default: 1.
        focal_window (int): region size of focal attention. Default: 1
        use_layerscale (bool): whether use layer scale for training stability. Default: False
        layerscale_value (float): scaling value for layer scale. Default: 1e-4
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=8, expand_size=0, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none",
                 focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4):
        super().__init__()
These are the main parameters of the Focal Transformer block; they are all commented in the code and fairly easy to understand.
The overall flow is as follows:
Step 1: apply layer normalization to the flattened token sequence produced by the earlier patch embedding, and reshape it back into a feature map.
def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)
Step 2: pad the feature map to a multiple of the window size, so the partition divides evenly. Then decide whether to use SW-MSA, i.e. whether to shift the feature map (just check whether shift_size > 0). If SW-MSA is used, attn_mask holds a mask, and so does x_window_masks_all; if it is None, plain W-MSA is performed. In other words, whether attn_mask is None is what selects between W-MSA and SW-MSA.
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
if pad_r > 0 or pad_b > 0:
    x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
B, H, W, C = x.shape

if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
    shifted_x = x

x_windows_all = [shifted_x]
x_window_masks_all = [self.attn_mask]
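The padding arithmetic in step 2 can be checked in isolation. This is just the two modulo expressions above wrapped in a helper; `pad_amounts` is a name I made up for illustration:

```python
# Pure-Python sketch of the padding arithmetic: the feature map is padded
# on the right/bottom so that H and W become multiples of window_size.
def pad_amounts(H, W, window_size):
    pad_r = (window_size - W % window_size) % window_size
    pad_b = (window_size - H % window_size) % window_size
    return pad_r, pad_b

print(pad_amounts(56, 56, 7))  # (0, 0) -- already divisible, no padding
print(pad_amounts(50, 60, 7))  # (3, 6) -- W needs +3, H needs +6
```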
Step 3: the first two steps are the same as in Swin, but now comes the difference, because there are multiple levels. To add the coarser granularity, the focal_level you pass in must be greater than 1 and pool_method must not be "none". The pooling method here is what the paper calls sub-window pooling. Only under these conditions do you actually get multiple levels.
Step 4: (same as Swin again, i.e. Swin's step 3) partition the feature map into windows. Note that this partitioning happens inside the loop below; if the parameters are not set correctly, focal self-attention is effectively never used at all.
if self.focal_level > 1 and self.pool_method != "none":
    # if we add coarser granularity and the pool method is not none
    for k in range(self.focal_level - 1):
        window_size_glo = math.floor(self.window_size_glo / (2 ** k))  # math.floor returns the largest integer <= its argument
        pooled_h = math.ceil(H / self.window_size) * (2 ** k)
        pooled_w = math.ceil(W / self.window_size) * (2 ** k)
        H_pool = pooled_h * window_size_glo
        W_pool = pooled_w * window_size_glo

        x_level_k = shifted_x
        # trim or pad shifted_x depending on the required size
        if H > H_pool:
            trim_t = (H - H_pool) // 2
            trim_b = H - H_pool - trim_t
            x_level_k = x_level_k[:, trim_t:-trim_b]
        elif H < H_pool:
            pad_t = (H_pool - H) // 2
            pad_b = H_pool - H - pad_t
            x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))

        if W > W_pool:
            trim_l = (W - W_pool) // 2
            trim_r = W - W_pool - trim_l
            x_level_k = x_level_k[:, :, trim_l:-trim_r]
        elif W < W_pool:
            pad_l = (W_pool - W) // 2
            pad_r = W_pool - W - pad_l
            x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))

        # partition the feature map into windows
        x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo)  # B, nw, nw, window_size, window_size, C
        nWh, nWw = x_windows_noreshape.shape[1:3]
        if self.pool_method == "mean":
            x_windows_pooled = x_windows_noreshape.mean([3, 4])  # B, nWh, nWw, C
        elif self.pool_method == "max":
            x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B, nWh, nWw, C)  # B, nWh, nWw, C
        elif self.pool_method == "fc":
            x_windows_noreshape = x_windows_noreshape.view(B, nWh, nWw, window_size_glo * window_size_glo, C).transpose(3, 4)  # B, nWh, nWw, C, wsize**2
            x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2)  # B, nWh, nWw, C
        elif self.pool_method == "conv":
            x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous()  # B * nw * nw, C, wsize, wsize
            x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B, nWh, nWw, C)  # B, nWh, nWw, C

        x_windows_all += [x_windows_pooled]
        x_window_masks_all += [None]
So the x_windows_all you end up with is a list, with as many entries as there are levels.
Here is an example of how things get partitioned. Say the input feature map is 56*56, and we look at the first block. Level 1 then has size 13*13, because we used (1, 13) above. Level 2 has size 7*7 (remember the (7, 7) above): a 7*7 receptive field in the original map is pooled down to a 1*1 receptive field at level 2. This is exactly the famous "coarse granularity", i.e. enlarging the receptive field.
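The geometry arithmetic at the top of the loop (window_size_glo, pooled_h, pooled_w) can be reproduced in plain Python. `pooled_geometry` is a hypothetical helper; the numbers assume the first stage with H = W = 56, window_size = 7 and window_size_glo = 7:

```python
import math

# Sketch of the per-level geometry from the loop above: each sub-window of
# size ws_glo is pooled to one token, producing a pooled_h x pooled_w map
# of coarse tokens at level k.
def pooled_geometry(H, W, window_size, window_size_glo, focal_level):
    sizes = []
    for k in range(focal_level - 1):
        ws_glo = math.floor(window_size_glo / (2 ** k))   # sub-window size at level k
        pooled_h = math.ceil(H / window_size) * (2 ** k)  # pooled map height
        pooled_w = math.ceil(W / window_size) * (2 ** k)  # pooled map width
        sizes.append((ws_glo, pooled_h, pooled_w))
    return sizes

# With H=W=56, window_size=7, window_size_glo=7 and focal_level=2, each 7x7
# sub-window is pooled to one token, giving an 8x8 pooled map from which the
# 7x7 coarse neighborhood per query window is later unfolded.
print(pooled_geometry(56, 56, 7, 7, 2))  # [(7, 8, 8)]
```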
Step 5: the attention itself (covered in detail below). Whether each entry of x_window_masks_all obtained above is None is what selects between W-MSA and SW-MSA.
attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all) # nW*B, window_size*window_size, C
attn_windows = attn_windows[:, :self.window_size ** 2]
Step 6: merge the windows back together; if a shift was applied earlier, perform the reverse shift here to undo it.
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x.permute(0, 3, 1, 2)
return x
Then, as needed, apply dropout and a residual connection, followed by a LayerNorm + fully connected layer with another dropout and residual connection. That completes one stage.
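Steps 4 and 6 are exact inverses of each other. Below is a minimal sketch of `window_partition` / `window_reverse` (my own simplified versions, which should match the repo's behavior up to reshapes) demonstrating the round trip:

```python
import torch

# Simplified window_partition / window_reverse (sketch, not the repo code).
def window_partition(x, ws):            # x: (B, H, W, C)
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    # -> (B * num_windows, ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

def window_reverse(windows, ws, H, W):  # windows: (B * num_windows, ws, ws, C)
    B = windows.shape[0] // (H // ws * W // ws)
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    # -> (B, H, W, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

x = torch.randn(2, 14, 14, 8)
win = window_partition(x, 7)            # (8, 7, 7, 8): 2 images * 2*2 windows
assert torch.equal(window_reverse(win, 7, 14, 14), x)  # exact round trip
```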
3.WindowAttention
I will skip the part before this that computes the relative position bias for the attention; it is well explained elsewhere online. In the code below, you can see that the q-map partitioning at the very beginning and the q * scale at the very end are the same as in Swin's code; the real question is what the big middle section is doing. Honestly, I cannot explain it line by line. But the gist of this long, unwieldy stretch is that it builds the full key/value set and thereby N, the total number of tokens across all levels. If you know that this whole chunk computes the total token count, you know most of what matters.
# partition q map
(q_windows, k_windows, v_windows) = map(
    lambda t: window_partition(t, self.window_size[0]).view(
        -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
    ).transpose(1, 2),
    (q, k, v)
)

if self.expand_size > 0 and self.focal_level > 0:
    (k_tl, v_tl) = map(
        lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
    )
    (k_tr, v_tr) = map(
        lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
    )
    (k_bl, v_bl) = map(
        lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
    )
    (k_br, v_br) = map(
        lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
    )

    (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
        lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
        (k_tl, k_tr, k_bl, k_br)
    )
    (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
        lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
        (v_tl, v_tr, v_bl, v_br)
    )
    k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
    v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)

    # mask out tokens in current window
    k_rolled = k_rolled[:, :, self.valid_ind_rolled]
    v_rolled = v_rolled[:, :, self.valid_ind_rolled]
    k_rolled = torch.cat((k_windows, k_rolled), 2)
    v_rolled = torch.cat((v_windows, v_rolled), 2)
else:
    k_rolled = k_windows
    v_rolled = v_windows

if self.pool_method != "none" and self.focal_level > 1:
    k_pooled = []
    v_pooled = []
    for k in range(self.focal_level - 1):
        stride = 2 ** k
        x_window_pooled = x_all[k + 1]  # B, nWh, nWw, C
        nWh, nWw = x_window_pooled.shape[1:3]

        # generate mask for pooled windows
        mask = x_window_pooled.new(nWh, nWw).fill_(1)
        unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
            1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
            view(nWh * nWw // stride // stride, -1, 1)
        if k > 0:
            valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
            unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]

        x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
        x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
        mask_all[k + 1] = x_window_masks

        # generate k and v for pooled windows
        qkv_pooled = self.qkv(x_window_pooled).reshape(B, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
        k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]  # B, C, nWh, nWw
        (k_pooled_k, v_pooled_k) = map(
            lambda t: self.unfolds[k](t).view(
                B, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                view(-1, self.unfolds[k].kernel_size[0] * self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
            (k_pooled_k, v_pooled_k)  # (B x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
        )
        if k > 0:
            (k_pooled_k, v_pooled_k) = map(
                lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
            )
        k_pooled += [k_pooled_k]
        v_pooled += [v_pooled_k]

    k_all = torch.cat([k_rolled] + k_pooled, 2)
    v_all = torch.cat([v_rolled] + v_pooled, 2)
else:
    k_all = k_rolled
    v_all = v_rolled

N = k_all.shape[-2]
q_windows = q_windows * self.scale
attn = (q_windows @ k_all.transpose(-2, -1))
To continue the earlier example with the same parameters:
level 1 has 13*13 = 169 tokens in total, and level 2 has 7*7 = 49 (the size obtained after the pooling in step 2 above), so N = 169 + 49 = 218. That is the total number of k/v tokens.
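The token arithmetic can be written out explicitly. Note that the 13 at level 1 comes from the 7*7 window expanded by expand_size tokens on each side, i.e. (7 + 2*3)^2 = 169; that expand_size = 3 value is my own assumption, based on the common default expand_size = window_size // 2:

```python
# Back-of-the-envelope token count for the example above (sketch; the
# expand_size value is assumed, not taken from the repo config).
window_size = 7
expand_size = 3                                       # assumed: window_size // 2
level1_tokens = (window_size + 2 * expand_size) ** 2  # 13*13 = 169 fine tokens
level2_tokens = 7 * 7                                 # 49 pooled coarse tokens
N = level1_tokens + level2_tokens
print(N)  # 218
```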
But when you actually run the code, although the arithmetic above gives 218, a total of 218 + 12 tokens ends up being used. Why? I am not entirely sure myself; you can read the author's explanation here:
https://github.com/microsoft/Focal-Transformer/issues/6
I still have not figured it out; for now, whatever token count you compute, adding 12 makes it work. If anyone understands this, please teach me, thanks!
Well, I think that covers the main content. Happy to discuss any questions!