【FOCAL TRANSFORMER】focal transformer和FOCAL SELF ATTENTION代码复现

最近研究了一下focal transformer论文,复现了一下代码,研究的不是很透彻,写写自己的看法。看网上论文当中理论和算法部分介绍的很全面,没有什么复现代码的,然后自己复现一下,写一写,transformer也算第一次研究。

 论文地址:https://arxiv.org/abs/2107.00641

代码地址:https://github.com/microsoft/Focal-Transformer

focal transformer可以说是在swin的基础上进行改动的,他与swin的代码是高度重合的。它最新颖的地方自然是新提出来的focal-self attention。

每一个stage也就是一个transformer模块,最大的创新点就是stage当中的focal self attention。这个attention我认为主要还是两部分组成:window_partition+window wise attention(这个也就是swin当中提出的w-msa,sw-msa代码里也给了,只不过我是没用。。。。)

概括一下,也就是 基础代码focal跟swin完全一样,不同的在哪呢?一个是focal要考虑不同的level,可以比较宽泛的理解为swin是在一个level上操作,那么focal在这个level上操作完后还要再加上粗粒度这个level上操作。第二个呢就是没有用swin新引入的sw-msa,而是用的w-msa(如果看代码的话其实也给到了,只不过是我没有用)

下面就分析一下这个FOCAL TRANSFORMER的代码(简单放一些比较重要的代码了,就不全放了)

1.patch_embeding

 首先肯定在进入transformer之前要做的就是patch embeding,可以看到比如说你输入的图片是(1,512,224,224)也就是说batchsize=1,通道数是512,图片尺寸是224*224的,那么你在调用的时候参数就按你输入的定义,emd_dim这个参数就是看你想经过Patch处理以及拉长后想要多少维度(通道数)。

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        use_conv_embed (bool): Wherther use overlapped convolutional embedding layer. Default: False.
        norm_layer (nn.Module, optional): Normalization layer. Default: None 
        use_pre_norm (bool): Whether use pre-normalization before projection. Default: False
        is_stem (bool): Whether current patch embedding is stem. Default: False
    """

    def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, 
                    use_conv_embed=False, norm_layer=None, use_pre_norm=False, is_stem=False):
        super().__init__()

ex:input:(8,64,256,256) 在下面这个调用下,输出就是(8,64*64,64)。其中batchsize不变,256*256在patch size为4的情况下,变为一幅图像分成了64个大小是4*4,64个通道的patch,如果这时候压缩成一维向量,长度应该是4*4*64.但是我们想要向量的长度是64,所以先经过线性层变成64。在压缩成一维向量。最终输出batchsize为8的64*64个长度为64的向量。

self.patch4 = PatchEmbed(img_size=(256, 256), patch_size=4, in_chans=64, embed_dim=64,
                                 use_conv_embed=False, norm_layer=None, use_pre_norm=False, is_stem=False)

2.window_partition

具体窗口怎么分可以看论文中的解释

 

然后window size啊,focal level啊,focal window这仨最重要的参数我就用的文章里这个配置。(代码里也是按1357给好的,建议按这个。本来我图片输入512我想换window size=8,然后1248,但是报错了,具体咋改没研究,先按1357跑了)。那么这个window_partition主要就是在class focaltransformer block里面实现的,然后呢看看代码

class FocalTransformerBlock(nn.Module):
    r""" Focal Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        expand_size (int): expand size at first focal level (finest level).
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm 
        pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
        focal_level (int): number of focal levels. Default: 1. 
        focal_window (int): region size of focal attention. Default: 1
        use_layerscale (bool): whether use layer scale for training stability. Default: False
        layerscale_value (float): scaling value for layer scale. Default: 1e-4
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=8, expand_size=0, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none", 
                 focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4):
        super().__init__()

 focal transformer block最主要参数也就是这些,代码里都有注释,也都比较好理解。

整体流程如下吧:

第一步:是对前面patch_embeding输入的拉长向量把它变回特征图,然后做一个标准化

    def forward(self, x):

        H, W = self.input_resolution

        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

第二步:把feature map给pad到window size的整数倍,pad一下特征图避免除不开。然后考虑一下自己要不要用sw-msa,判断是否需要对特征图进行shift(看shift_size是不是大于0就完事了),也就是说如果用的是sw-msa,masks_all会有值,masks_all如果是0,那么就是用w-msa通过attn_mask是否为None判断进行W-MSA还是SW-MSA

 # pad feature maps to multiples of window size 
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        
        B, H, W, C = x.shape    
        
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        
        x_windows_all = [shifted_x]
        x_window_masks_all = [self.attn_mask]

第三步:前面两步都跟swin一样吧,然后不一样的来了,因为是多个Level了。如果想加入粗粒度,就要保证你所给入的focal level>0,并且在pool_mthods不能选none。在这里这个池化方式也就是你在论文当中的sub-window pooling用什么方法。只有这样你才能选多个level。
第四步:(又跟swin一样了,也就是swin的第三步)将特征图切成一个个窗口。可以看到这个切窗口是在循环里的,如果你没有设对,你是相当于focal self attention压根就没有用到。

        if self.focal_level > 1 and self.pool_method != "none": 
            # if we add coarser granularity and the pool method is not none
            for k in range(self.focal_level-1):     
                window_size_glo = math.floor(self.window_size_glo / (2 ** k))#math.floor(返回一个小于或者等于的最大值)
                pooled_h = math.ceil(H / self.window_size) * (2 ** k)
                pooled_w = math.ceil(W / self.window_size) * (2 ** k)
                H_pool = pooled_h * window_size_glo
                W_pool = pooled_w * window_size_glo

                x_level_k = shifted_x
                # trim or pad shifted_x depending on the required size
                if H > H_pool:
                    trim_t = (H - H_pool) // 2
                    trim_b = H - H_pool - trim_t
                    x_level_k = x_level_k[:, trim_t:-trim_b]
                elif H < H_pool:
                    pad_t = (H_pool - H) // 2
                    pad_b = H_pool - H - pad_t
                    x_level_k = F.pad(x_level_k, (0,0,0,0,pad_t,pad_b))
                
                if W > W_pool:
                    trim_l = (W - W_pool) // 2
                    trim_r = W - W_pool - trim_l
                    x_level_k = x_level_k[:, :, trim_l:-trim_r]
                elif W < W_pool:
                    pad_l = (W_pool - W) // 2
                    pad_r = W_pool - W - pad_l
                    x_level_k = F.pad(x_level_k, (0,0,pad_l,pad_r))
                #将特征图切成一个个的窗口
                x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B, nw, nw, window_size, window_size, C    
                nWh, nWw = x_windows_noreshape.shape[1:3]
                if self.pool_method == "mean":
                    x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B, nWh, nWw, C
                elif self.pool_method == "max":
                    x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B, nWh, nWw, C) # B, nWh, nWw, C                    
                elif self.pool_method == "fc":
                    x_windows_noreshape = x_windows_noreshape.view(B, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B, nWh, nWw, C, wsize**2
                    x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2) # B, nWh, nWw, C                      
                elif self.pool_method == "conv":
                    x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B * nw * nw, C, wsize, wsize
                    x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B, nWh, nWw, C) # B, nWh, nWw, C           

                x_windows_all += [x_windows_pooled]
                x_window_masks_all += [None]

这样你最终得到的x_windows_all才是一个列表,你的level是几,里面就有几个值。

举个例子,他是怎么切的呢比如你输入的图像是56*56,咱们拿第一个模块举例子。那么你level 1的尺寸就是13*13,因为咱们上面不是用的是(1,13)吗。然后level 2的尺寸也就是7*7(咱们上面不是有个(7,7)吗,就是相当与原图中的7*7的感受野在level2中池化后变成1*1的感受野。这不也就是传说中的粗粒度吗,扩大感受野。)

第五步:就是attention了(咱们下面详细写)这块通过上面得到的masks_all是否为none,来选择是w-msa还是sw-msa

attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)  # nW*B, window_size*window_size, C
        attn_windows = attn_windows[:, :self.window_size ** 2]

第六步:将各个窗口合并回来如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复。

attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        x = shifted_x.permute(0, 3, 1, 2)

        return x

 然后看需要做做dropout和残差连接,再通过一层LayerNorm+全连接层,以及dropout和残差连接。就完成了一个stage了。

3.WindowAttention

 然后这里前面那些就是算注意力位置偏移的就不写了,网上其他写的很清楚,然后我们可以看到是不是在下面这个代码里最开头的partition map和最末尾的给q*scale,这两步是不是就是swin的代码,主要问题就是中间这一大段是个啥玩意。具体让我解释,我也解释不出来。不过我写一下这一大段又臭又长的东西就是为了得到N,这个N也就是你所有level下的token总数。是不是只要知道这一大堆是求token总数的了,也就差不多了。

# partition q map
        (q_windows, k_windows, v_windows) = map(
            lambda t: window_partition(t, self.window_size[0]).view(
            -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
            ).transpose(1, 2), 
            (q, k, v)
        )

        if self.expand_size > 0 and self.focal_level > 0:
            (k_tl, v_tl) = map(
                lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
            )
            (k_tr, v_tr) = map(
                lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
            )
            (k_bl, v_bl) = map(
                lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
            )
            (k_br, v_br) = map(
                lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
            )        
            
            (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
                lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads), 
                (k_tl, k_tr, k_bl, k_br)
            )            
            (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
                lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads), 
                (v_tl, v_tr, v_bl, v_br)
            )
            k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
            v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)
            
            # mask out tokens in current window
            k_rolled = k_rolled[:, :, self.valid_ind_rolled]
            v_rolled = v_rolled[:, :, self.valid_ind_rolled]
            k_rolled = torch.cat((k_windows, k_rolled), 2)
            v_rolled = torch.cat((v_windows, v_rolled), 2)
        else:
            k_rolled = k_windows; v_rolled = v_windows; 

        if self.pool_method != "none" and self.focal_level > 1:
            k_pooled = []
            v_pooled = []
            for k in range(self.focal_level-1):
                stride = 2**k
                x_window_pooled = x_all[k+1]  # B, nWh, nWw, C
                nWh, nWw = x_window_pooled.shape[1:3] 

                # generate mask for pooled windows
                mask = x_window_pooled.new(nWh, nWw).fill_(1)
                unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
                    1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                    view(nWh*nWw // stride // stride, -1, 1)

                if k > 0:
                    valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
                    unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]

                x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
                x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))            
                mask_all[k+1] = x_window_masks

                # generate k and v for pooled windows                
                qkv_pooled = self.qkv(x_window_pooled).reshape(B, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
                k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]  # B, C, nWh, nWw


                (k_pooled_k, v_pooled_k) = map(
                    lambda t: self.unfolds[k](t).view(
                    B, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                    view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2), 
                    (k_pooled_k, v_pooled_k)  # (B x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
                )

                if k > 0:                    
                    (k_pooled_k, v_pooled_k) = map(
                        lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
                    )

                k_pooled += [k_pooled_k]
                v_pooled += [v_pooled_k]
            k_all = torch.cat([k_rolled] + k_pooled, 2)
            v_all = torch.cat([v_rolled] + v_pooled, 2)
        else:
            k_all = k_rolled
            v_all = v_rolled
        N = k_all.shape[-2]
        
        q_windows = q_windows * self.scale
        attn = (q_windows @ k_all.transpose(-2, -1))

举个例子在,还是这个参数

level 1 是不是token一共是13*13+169个,然后level2是不是一共7*7=49个(就是第二步池化后得到的尺寸)这个N=169+49=218.也就是下面这个图最后k,v的总的token。

但是最终实现代码的时候会发现,虽然是上面算出来218,实际上用了一共218+12个token。为什么呢,我也不是很明白哦,可以看看作者这个解释。

https://github.com/microsoft/Focal-Transformer/issues/6

我也没搞明白,反正现在就是无论你算出来多少个token加12就对了,如果有明白的希望可以教教我,谢谢啦。

额,感觉是不是主要内容也就是这些了,有问题可以再讨论!

  • 4
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值