Is AutoFocusFormer's spatial clustering worth borrowing for local attention?

AutoFocusFormer code: https://github.com/apple/ml-autofocusformer

Explanation of the key ideas of the paper: see my previous article.

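"AutoFocusFormer implements its clustering-based computation through three parts: spatial clustering, local attention, and adaptive sampling. This is the core innovation of AutoFocusFormer."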

0. Preface

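1) This article pulls out a few key pieces of code. Note: the quoted excerpts can all be found in the original repository; they are not additions of my own.
2) The sections on the three core innovations are meant to be read alongside the paper walkthrough in my previous article.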

1. Pipeline

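In the official code, the BasicLayer class in /models/aff_transformer.py shows how the clustering computation proceeds:
1) First, spatial partitioning: the tokens of the image are divided into clusters, and tokens in the same cluster are considered close to each other. This is done by the space_filling_cluster function.
2) For each token, find its few nearest clusters and collect the tokens of those clusters to form the token's neighborhood. This is done by knn_keops plus a few gather operations.
3) In the attention computation, each token attends only to the other tokens in its neighborhood. This makes the attention local.
4) Finally, the ClusterMerging layer samples tokens according to their importance, keeping important tokens and discarding unimportant ones. This downsamples the token set stage by stage.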
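Step 2 does not use the requested neighborhood size literally. BasicLayer.forward falls back to global attention when nbhd_size covers every token, and otherwise rounds the neighborhood to a whole number of clusters. That arithmetic can be restated as a small standalone function; neighborhood_plan is a hypothetical name, and returning n as the neighborhood size in the global case is my own convention, not the repository's:

```python
import math

def neighborhood_plan(n, cluster_size, nbhd_size):
    """Hypothetical helper restating BasicLayer's neighborhood bookkeeping.

    Returns (global_attn, k, nnc, effective_nbhd_size).
    """
    assert cluster_size > 0, "cluster_size must be positive"
    if nbhd_size >= n:
        # the neighborhood would cover every token: BasicLayer uses global attention
        return True, None, None, n
    k = int(math.ceil(n / float(cluster_size)))                # number of clusters
    nnc = min(int(round(nbhd_size / float(cluster_size))), k)  # nearest clusters per token
    return False, k, nnc, cluster_size * nnc                   # multiple of cluster_size

# 56 x 56 tokens, clusters of 8 tokens, requested neighborhood of 48 tokens
print(neighborhood_plan(56 * 56, 8, 48))  # (False, 392, 6, 48)
# 49 tokens with a requested neighborhood of 64 tokens
print(neighborhood_plan(49, 8, 64))       # (True, None, None, 49)
```

Because Python's round() rounds halves to even, a request of exactly 2.5 clusters' worth of tokens becomes 2 clusters.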

2. Spatial clustering code

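This is implemented mainly in the forward method of BasicLayer. First, space_filling_cluster partitions the tokens into clusters, so that the tokens within one cluster are close to each other. Then, for each token, knn_keops finds its nearest clusters, and gather operations collect the tokens of those clusters to form the token's neighborhood.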
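The nearest-cluster neighborhood construction can be sketched in a few lines of PyTorch. This is a simplified sketch, not the repository's code: it assumes the tokens are already reordered so that every cluster_size consecutive tokens form one cluster (which is what space_filling_cluster returns when n is divisible by the cluster size), and it swaps knn_keops for torch.cdist plus topk. build_neighborhoods is a hypothetical name.

```python
import torch

def build_neighborhoods(pos, cluster_size, nnc):
    """Hypothetical sketch of BasicLayer's neighborhood construction.

    pos: b x n x d, already ordered so that consecutive groups of
    cluster_size tokens are clusters (n divisible by cluster_size assumed).
    """
    b, n, d = pos.shape
    m = cluster_size
    k = n // m
    cluster_mean_pos = pos.reshape(b, k, m, d).mean(2)                                  # b x k x d
    member_idx = torch.arange(n, device=pos.device).reshape(1, k, m).expand(b, -1, -1)  # b x k x m
    # the nnc clusters whose centers are closest to each token (stand-in for knn_keops)
    dist = torch.cdist(pos, cluster_mean_pos)                                           # b x n x k
    nearest_cluster = dist.topk(nnc, dim=-1, largest=False).indices                     # b x n x nnc
    # same gather pattern as BasicLayer.forward: take all m members of each chosen cluster
    index = nearest_cluster.reshape(b, -1, 1).expand(-1, -1, m)                         # b x n*nnc x m
    nbhd = member_idx.gather(index=index, dim=1)
    return nbhd.reshape(b, n, nnc * m)                                                  # b x n x nnc*m

# 2 x 4 grid in serpentine order: row 0 left to right, row 1 right to left
pos = torch.tensor([[[0., 0.], [1., 0.], [2., 0.], [3., 0.],
                     [3., 1.], [2., 1.], [1., 1.], [0., 1.]]])
nbhd = build_neighborhoods(pos, cluster_size=2, nnc=2)
print(nbhd[0, 0].tolist())  # [0, 1, 6, 7]
print(nbhd[0, 2].tolist())  # [2, 3, 4, 5]
```

Token 0's neighborhood is its own cluster plus the cluster directly below it, which is the point of the serpentine ordering: clusters that are far apart in the sequence can still be spatial neighbors.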

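import math
import torch
import torch.nn as nn

# rel_pos_width, table_width (used below) and pre_table (used in ClusterAttention and
# ClusterMerging) are module-level globals defined elsewhere in the same file: the offset
# that makes relative positions non-negative, the side length of the relative-position
# lookup table, and the precomputed table itself.
class BasicLayer(nn.Module):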
    """ AutoFocusFormer layer for one stage.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        cluster_size (int): Cluster size.
        nbhd_size (int): Neighbor size. If larger than or equal to number of tokens, perform global attention;
                            otherwise, rounded to the nearest multiples of cluster_size.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0
        ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25
        reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        layer_scale (float, optional): Layer scale initial parameter. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self, dim, out_dim, cluster_size, nbhd_size,
                 depth, num_heads, mlp_ratio,
                 alpha=4.0, ds_rate=0.25, reserve_on=True,
                 drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm,
                 layer_scale=0.0, downsample=None):

        super().__init__()
        self.dim = dim
        self.nbhd_size = nbhd_size
        self.cluster_size = cluster_size
        self.depth = depth

        # build blocks
        self.blocks = nn.ModuleList([
            ClusterTransformerBlock(dim=dim,
                                    num_heads=num_heads,
                                    mlp_ratio=mlp_ratio,
                                    drop=drop, attn_drop=attn_drop,
                                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                    layer_scale=layer_scale,
                                    norm_layer=norm_layer)
            for i in range(depth)])

        # merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, out_dim=out_dim, norm_layer=norm_layer, alpha=alpha, ds_rate=ds_rate, reserve_on=reserve_on)
        else:
            self.downsample = None

        # cache the clustering result for the first feature map since it is on grid
        self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = None, None, None, None, None

        # fc for importance scores
        if downsample is not None:
            self.prob_net = nn.Linear(dim, 1)

    def forward(self, pos, feat, h, w, on_grid, stride):
        """
        Args:
            pos - b x n x 2, token positions
            feat - b x n x c, token features
            h,w - max height and width of token positions
            on_grid - bool, whether the tokens are still on grid; True for the first feature map
            stride - int, "stride" of the current token set; starts with 2, then doubles in each stage
        """
        b, n, d = pos.shape
        c = feat.shape[2]
        assert self.cluster_size > 0, 'self.cluster_size must be positive'

        if self.nbhd_size >= n:
            global_attn = True
            member_idx, cluster_mask = None, None
        else:
            global_attn = False
            k = int(math.ceil(n / float(self.cluster_size)))  # number of clusters
            nnc = min(int(round(self.nbhd_size / float(self.cluster_size))), k)  # number of nearest clusters
            nbhd_size = self.cluster_size * nnc
            self.nbhd_size = nbhd_size  # if not global attention, then nbhd size is rounded to nearest multiples of cluster

        if global_attn:
            rel_pos = (pos[:, None, :, :]+rel_pos_width) - pos[:, :, None, :]  # b x n x n x d
        else:
            if k == n:
                # if number of clusters equal to number of tokens
                cluster_mean_pos = pos
                member_idx = torch.arange(n, device=feat.device).long().reshape(1, n, 1).expand(b, -1, -1)  # b x n x 1
                cluster_mask = None
            else:
                # perform clustering
                if on_grid:
                    if self.cluster_mean_pos is None:
                        self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False)
                    pos, cluster_mean_pos, member_idx, cluster_mask = self.pos[:b], self.cluster_mean_pos[:b], self.member_idx[:b], self.cluster_mask
                    # reorder the tokens so that tokens in same cluster are stored together
                    feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), self.reorder[:b].view(-1)].reshape(b, n, c)
                    if cluster_mask is not None:
                        cluster_mask = cluster_mask[:b]
                else:
                    pos, cluster_mean_pos, member_idx, cluster_mask, reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False)
                    # reorder the tokens so that tokens in same cluster are stored together
                    feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), reorder.view(-1)].reshape(b, n, c)

            assert member_idx.shape[1] == k and member_idx.shape[2] == self.cluster_size, "member_idx shape incorrect!"

            nearest_cluster = knn_keops(pos, cluster_mean_pos, nnc)  # b x n x nnc

            # collect neighbor indices from nearest clusters
            m = self.cluster_size
            member_idx = member_idx.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size)  # b x n x nnc*m
            if cluster_mask is not None:
                cluster_mask = cluster_mask.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size)
            pos_ = pos.gather(index=member_idx.view(b, -1, 1).expand(-1, -1, d), dim=1).reshape(b, n, nbhd_size, d)
            rel_pos = pos_ - (pos.unsqueeze(2)-rel_pos_width)  # b x n x nbhd_size x d

        # compute indices in the position embedding lookup table
        pe_idx = (rel_pos[..., 1] * table_width + rel_pos[..., 0]).long()

        for i_blk in range(len(self.blocks)):
            blk = self.blocks[i_blk]
            feat = blk(feat=feat,
                       member_idx=member_idx,
                       cluster_mask=cluster_mask,
                       pe_idx=pe_idx,
                       global_attn=global_attn)

        if self.downsample is not None:
            learned_prob = self.prob_net(feat).sigmoid()  # b x n x 1
            reserve_num = math.ceil(h/(stride*2)) * math.ceil(w/(stride*2))
            pos, feat = self.downsample(pos=pos, feat=feat,
                                        member_idx=member_idx, cluster_mask=cluster_mask,
                                        learned_prob=learned_prob, stride=stride,
                                        pe_idx=pe_idx, reserve_num=reserve_num)

        return pos, feat

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}"
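The pe_idx line in BasicLayer.forward turns each (dx, dy) offset into a single index into a flattened table_width x table_width lookup table; adding rel_pos_width first makes the offsets non-negative. rel_pos_width and table_width are module-level globals that the excerpt does not show, so the small values below are hypothetical, as is the random per-head table standing in for the one ClusterAttention derives from pre_table:

```python
import torch

rel_pos_width = 3                      # hypothetical value; the real one is a module global
table_width = 2 * rel_pos_width + 1    # assumption: offsets lie in [-rel_pos_width, rel_pos_width]

pos = torch.tensor([[[0, 0], [1, 0], [0, 2]]])          # b x n x 2 integer positions (x, y)
member_idx = torch.tensor([[[0, 1], [1, 2], [2, 0]]])   # b x n x nbhd neighbor indices
b, n, d = pos.shape
nbhd_size = member_idx.shape[-1]

# neighbor positions, then offsets shifted to be non-negative (same formula as BasicLayer)
pos_ = pos.gather(index=member_idx.view(b, -1, 1).expand(-1, -1, d), dim=1).reshape(b, n, nbhd_size, d)
rel_pos = pos_ - (pos.unsqueeze(2) - rel_pos_width)     # b x n x nbhd x 2, values in [0, table_width)
pe_idx = (rel_pos[..., 1] * table_width + rel_pos[..., 0]).long()

# hypothetical learned bias per head for every cell of the flattened table
num_heads = 2
pe_table = torch.randn(table_width * table_width, num_heads)
pos_embed = pe_table.gather(index=pe_idx.view(-1, 1).expand(-1, num_heads), dim=0)
pos_embed = pos_embed.reshape(*pe_idx.shape, num_heads).permute(0, 3, 1, 2)  # b x h x n x nbhd
print(pe_idx.tolist())  # [[[24, 25], [24, 37], [24, 10]]]
```

Index 24 is the center of the 7 x 7 table, i.e. the zero offset, so every token looks up the same bias for attending to itself.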

The space_filling_cluster function is defined in point_utils.py:

def space_filling_cluster(pos, m, h, w, no_reorder=False, sf_type='', use_anchor=True):
    """
    The balanced clustering algorithm based on space-filling curves
    In the case where number of tokens not divisible by cluster size,
    the last cluster will have a few blank spots, indicated by the mask returned
    Args:
        pos - b x n x 2, positions of tokens
        m - int, target size of the clusters
        h,w - int, height and width
        no_reorder - bool, if True, return the clustering based on the original order of tokens;
                            otherwise, reorder the tokens so that the same cluster stays together
        sf_type - str, can be 'peano' or 'hilbert', or otherwise, horizontal scanlines w/ alternating
                        direction in each row by default
        use_anchor - bool, whether to use space-filling anchors or not; if False, directly compute
                            space-filling curves on the token positions
    Returns:
        pos - b x n x 2, returned only if no_reorder is False; the reordered position of tokens
        cluster_mean_pos - b x k x 2, the clustering centers
        member_idx - b x k x m, the indices of tokens in each cluster
        cluster_mask - b x k x m, the binary mask indicating the paddings in last cluster (0 if padding)
        pos_ranking - b x n x 1, returned only if no_reorder is False; i-th entry is the idx of the token
                                rank i in the new order
    """
    with torch.no_grad():
        pos = pos.detach()

        if pos.dtype != torch.float:
            pos = pos.to(torch.float)
        b, n, d = pos.shape

        k = int(math.ceil(n/m))

        if use_anchor:
            patch_len = (h*w/k)**0.5
            num_patch_h = int(round(h / patch_len))
            num_patch_w = int(round(w / patch_len))
            patch_len_h, patch_len_w = h / num_patch_h, w / num_patch_w
            if sf_type == 'peano':
                num_patch_h = max(3, int(3**round(math.log(num_patch_h, 3))))
                patch_len_h = h / num_patch_h
                num_patch_w = int(round(w / h * 3) * (num_patch_h / 3))
                patch_len_w = w / num_patch_w
            elif sf_type == 'hilbert':
                num_patch_h = max(2, int(2**round(math.log(num_patch_h, 2))))
                patch_len_h = h / num_patch_h
                num_patch_w = int(round(w / h * 2) * (num_patch_h / 2))
                patch_len_w = w / num_patch_w
            hs = torch.arange(0, num_patch_h, device=pos.device)
            ws = torch.arange(0, num_patch_w, device=pos.device)
            ys, xs = torch.meshgrid(hs, ws)
            grid_pos = torch.stack([xs, ys], dim=2)  # h x w x 2
            grid_pos = grid_pos.reshape(-1, 2)

            # sort the grid centers to one line
            if sf_type == 'peano':
                order_grid_idx, order_idx = calculate_peano_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0))
                order_grid_idx = order_grid_idx[0]
                order_idx = order_idx[0]
            elif sf_type == 'hilbert':
                order_grid_idx, order_idx = calculate_hilbert_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0))
                order_grid_idx = order_grid_idx[0]
                order_idx = order_idx[0]
            else:
                order_mask = torch.ones_like(ys)  # h x w
                order_mask[1::2] = -1
                order_mask = order_mask * xs
                order_mask = order_mask + ys*w
                order_mask[1::2] += (w-1)
                order_mask = order_mask.reshape(-1)
                order_idx = order_mask.sort()[1]
                order_idx_src = torch.arange(len(order_idx)).to(pos.device)
                order_grid_idx = torch.zeros_like(order_idx_src)
                order_grid_idx.scatter_(index=order_idx, dim=0, src=order_idx_src)

            ordered_grid = grid_pos[order_idx]
            patch_len_hw = torch.Tensor([patch_len_w, patch_len_h]).to(pos.device)

            init_pos_means = ordered_grid * patch_len_hw + patch_len_hw/2 - 0.5
            nump = ordered_grid.shape[0]

            prev_means = torch.zeros_like(init_pos_means)
            prev_means[1:] = init_pos_means[:nump-1].clone()
            prev_means[0] = prev_means[1] - (prev_means[2]-prev_means[1])  # float('inf')
            next_means = torch.zeros_like(init_pos_means)
            next_means[:nump-1] = init_pos_means[1:].clone()
            next_means[-1] = next_means[-2] + (next_means[-2]-next_means[-3])  # float('inf')

            mean_assignment = (pos / patch_len_hw).floor()
            mean_assignment = mean_assignment[..., 0] + mean_assignment[..., 1] * num_patch_w
            mean_assignment = order_grid_idx.unsqueeze(0).expand(b, -1).gather(index=mean_assignment.long(), dim=1).unsqueeze(2)  # b x n x 1

            prev_mean_assign = prev_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1)  # b x n x d
            next_mean_assign = next_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1)  # b x n x d
            dist_prev = (pos-prev_mean_assign).pow(2).sum(-1)  # b x n
            dist_next = (pos-next_mean_assign).pow(2).sum(-1)
            dist_ratio = dist_prev / (dist_next + 1e-5)

            pos_ranking = mean_assignment * (dist_ratio.max()+1) + dist_ratio.unsqueeze(2)
            pos_ranking = pos_ranking.sort(dim=1)[1]  # b x n x 1

        else:
            if sf_type == 'peano':
                _, pos_ranking = calculate_peano_order(h, w, pos)
            elif sf_type == 'hilbert':
                _, pos_ranking = calculate_hilbert_order(h, w, pos)
            else:
                hs = torch.arange(0, h, device=pos.device)
                ws = torch.arange(0, w, device=pos.device)
                ys, xs = torch.meshgrid(hs, ws)
                order_mask = torch.ones_like(ys)  # h x w
                order_mask[1::2] = -1
                order_mask = order_mask * xs
                order_mask = order_mask + ys*w
                order_mask[1::2] += (w-1)
                order_mask = order_mask.reshape(-1)
                pos_idx = pos[..., 0] + pos[..., 1] * w
                order_mask = order_mask.gather(index=pos_idx.long().reshape(-1), dim=0).reshape(b, n)
                pos_ranking = order_mask.sort()[1]
            pos_ranking = pos_ranking.unsqueeze(2)

        pos = pos.gather(index=pos_ranking.expand(-1, -1, d), dim=1)  # b x n x d

        if k*m == n:
            cluster_mask = None
            cluster_mean_pos = pos.reshape(b, k, -1, d).mean(2)
        else:
            pos_pad = torch.zeros(b, k*m, d, dtype=pos.dtype, device=pos.device)
            pos_pad[:, :n] = pos.clone()
            cluster_mask = torch.zeros(b, k*m, device=pos.device).long()
            cluster_mask[:, :n] = 1
            cluster_mask = cluster_mask.reshape(b, k, m)
            cluster_mean_pos = pos_pad.reshape(b, k, -1, d).sum(2) / cluster_mask.sum(2, keepdim=True)

        if no_reorder:
            if k*m == n:
                member_idx = pos_ranking.reshape(b, k, m)
            else:
                member_idx = torch.zeros(b, k*m, device=pos.device, dtype=torch.int64)
                member_idx[:, :n] = pos_ranking.squeeze(2)
                member_idx = member_idx.reshape(b, k, m)
            return cluster_mean_pos, member_idx, cluster_mask
        else:
            member_idx = torch.arange(k*m, device=pos.device)
            member_idx[n:] = 0
            member_idx = member_idx.unsqueeze(0).expand(b, -1)  # b x k*m
            member_idx = member_idx.reshape(b, k, m)

            return pos, cluster_mean_pos, member_idx, cluster_mask, pos_ranking
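With use_anchor=False and the default sf_type, space_filling_cluster reduces to sorting tokens along a snake-shaped scanline (even rows left to right, odd rows right to left) and cutting the sequence into chunks of m, with a mask for the padded slots of the last cluster. A simplified standalone sketch of that path, leaving out the anchor-based ranking, the Peano/Hilbert options and the member_idx output; serpentine_clusters is a hypothetical name:

```python
import math
import torch

def serpentine_clusters(pos, m, h, w):
    """Hypothetical sketch of the use_anchor=False, default-sf_type path.

    pos: b x n x 2 (x, y) positions on an h x w integer grid.
    """
    b, n, d = pos.shape
    k = int(math.ceil(n / m))
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    # rank of each grid cell along the snake (same values as order_mask in the original)
    order_mask = torch.where(ys % 2 == 0, ys * w + xs, ys * w + (w - 1 - xs)).reshape(-1)
    pos_idx = (pos[..., 0] + pos[..., 1] * w).long()       # flat grid index of each token
    pos_ranking = order_mask[pos_idx].sort(dim=1)[1]       # b x n, token order along the snake
    pos = pos.gather(index=pos_ranking.unsqueeze(2).expand(-1, -1, d), dim=1).float()
    # pad to k*m slots; the mask is 1 for real tokens and 0 for padding in the last cluster
    pos_pad = torch.zeros(b, k * m, d)
    pos_pad[:, :n] = pos
    cluster_mask = torch.zeros(b, k * m, dtype=torch.long)
    cluster_mask[:, :n] = 1
    cluster_mask = cluster_mask.reshape(b, k, m)
    cluster_mean_pos = pos_pad.reshape(b, k, m, d).sum(2) / cluster_mask.sum(2, keepdim=True)
    return pos, cluster_mean_pos, cluster_mask, pos_ranking

# 6 tokens of a 2 x 3 grid in raster order, clusters of 4 (so the last cluster is padded)
pos = torch.tensor([[[0., 0.], [1., 0.], [2., 0.], [0., 1.], [1., 1.], [2., 1.]]])
pos_sorted, means, mask, ranking = serpentine_clusters(pos, m=4, h=2, w=3)
print(ranking[0].tolist())  # [0, 1, 2, 5, 4, 3]
```

The padded mean divides by the number of real members only, matching the original's use of cluster_mask.sum(2).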

The nearest-cluster search in BasicLayer uses knn_keops:

def knn_keops(query, database, k, return_dist=False):
    """
    Compute k-nearest neighbors using the Keops library
    Backward pass turned off; Keops does not provide backward pass for distance
    Args:
        query - b x n_ x c, the position of tokens looking for knn
        database - b x n x c, the candidate tokens for knn
        k - int, the number of neighbors to be found
        return_dist - bool, whether to return distance to the neighbors
    Returns:
        nn_idx - b x n_ x k, the indices of the knn
        nn_dist - b x n_ x k, if return_dist, the distance to the knn
    """
    b, n, c = database.shape
    with torch.no_grad():
        query = query.detach()
        database = database.detach()
        # Keops does not support half precision
        if query.dtype != torch.float32:
            query = query.to(torch.float32)
        if database.dtype != torch.float32:
            database = database.to(torch.float32)
        from pykeops.torch import LazyTensor
        query_ = LazyTensor(query[:, None, :, :])
        database_ = LazyTensor(database[:, :, None, :])
        dist = ((query_-database_) ** 2).sum(-1) ** 0.5  # b x n x n_
    if return_dist:
        nn_dist, nn_idx = dist.Kmin_argKmin(k, dim=1)  # b x n_ x k
        return nn_idx, nn_dist
    else:
        nn_idx = dist.argKmin(k, dim=1)  # b x n_ x k
        return nn_idx
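knn_keops needs the pykeops package. For small inputs, or for checking results, a brute-force stand-in with the same shapes can be written with torch.cdist and topk. This is a hypothetical helper (knn_torch), not part of the repository, and it materializes the full distance matrix, which is exactly what KeOps avoids:

```python
import torch

def knn_torch(query, database, k, return_dist=False):
    """Hypothetical brute-force replacement for knn_keops.

    query: b x n_ x c, database: b x n x c -> indices (and distances) b x n_ x k.
    """
    with torch.no_grad():
        dist = torch.cdist(query.detach().float(), database.detach().float())  # b x n_ x n
        nn_dist, nn_idx = dist.topk(k, dim=-1, largest=False)                  # ascending distance
    return (nn_idx, nn_dist) if return_dist else nn_idx

idx = knn_torch(torch.tensor([[[0.9]]]), torch.tensor([[[0.], [1.], [3.]]]), 2)
print(idx.tolist())  # [[[1, 0]]]

# nearest *other* token, as ClusterMerging uses knn_keops(pos, pos, 2, return_dist=True)
pos = torch.tensor([[[0., 0.], [2., 0.], [0., 2.], [2., 2.]]])
_, dist = knn_torch(pos, pos, 2, return_dist=True)
min_dist = dist[:, :, 1]
```

Column 0 of the self-query is each token itself (distance 0), which is why ClusterMerging reads column 1.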

gather is a PyTorch tensor operation that collects elements from the input tensor at the given indices along a specified dimension.
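For dim=1 on a 3-D tensor, gather computes out[b][i][j] = input[b][index[b][i][j]][j]. Below is a toy version of the member_idx gather in BasicLayer.forward, where each row of the index selects one whole cluster; the values are made up for illustration:

```python
import torch

# member_idx for b=1, k=3 clusters of m=2 tokens each (b x k x m)
member_idx = torch.tensor([[[0, 1], [2, 3], [4, 5]]])
# for each of n=2 tokens, its nnc=2 nearest clusters (b x n x nnc)
nearest_cluster = torch.tensor([[[0, 2], [1, 0]]])
b, n, nnc = nearest_cluster.shape
m = member_idx.shape[-1]

# repeat every cluster id m times along the last dim so one index row copies a whole cluster
index = nearest_cluster.view(b, -1, 1).expand(-1, -1, m)        # b x (n*nnc) x m
nbhd = member_idx.gather(index=index, dim=1).reshape(b, n, nnc * m)
print(nbhd.tolist())  # [[[0, 1, 4, 5], [2, 3, 0, 1]]]
```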

3. Local attention

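When computing attention, each token attends only to the other tokens in its neighborhood. ClusterAttention achieves this by taking member_idx (the neighbor indices of each token) and cluster_mask (which marks padded neighbor slots) as inputs.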
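Conceptually, this is full attention with every non-neighbor masked out. The sketch below expands member_idx and cluster_mask into a dense n x n boolean mask; dense_neighbor_mask is a hypothetical checking aid, and its O(n^2) memory is exactly what the neighborhood representation avoids. scatter_add_ is used so that padded slots, which still carry an index (0 in space_filling_cluster), contribute nothing:

```python
import torch

def dense_neighbor_mask(member_idx, cluster_mask, n):
    """Hypothetical helper: b x n x nbhd neighbor indices -> b x n x n bool attend-mask."""
    b = member_idx.shape[0]
    if cluster_mask is None:
        valid = torch.ones_like(member_idx)      # every slot is a real neighbor
    else:
        valid = cluster_mask.long()
    counts = torch.zeros(b, n, n, dtype=torch.long)
    # counts[b, i, member_idx[b, i, j]] += valid[b, i, j]; padding adds 0
    counts.scatter_add_(2, member_idx, valid)
    return counts > 0

# 4 tokens on a ring, each attending to itself and the next token
allowed = dense_neighbor_mask(torch.tensor([[[0, 1], [1, 2], [2, 3], [3, 0]]]), None, 4)
print(allowed[0].long().tolist())  # [[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]]
```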

class ClusterAttention(nn.Module):
    """
    Performs local attention on nearest clusters

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.pos_dim = 2
        self.num_heads = num_heads

        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2*dim)
        self.softmax = nn.Softmax(dim=-1)

        self.blank_k = nn.Parameter(torch.randn(dim))
        self.blank_v = nn.Parameter(torch.randn(dim))

        self.pos_embed = nn.Linear(self.pos_dim+3, num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, feat, member_idx, cluster_mask, pe_idx, global_attn):
        """
        Args:
            feat - b x n x c, token features
            member_idx - b x n x nbhd, token idx in each local nbhd
            cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
            pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
            global_attn - bool, whether to perform global attention
        """

        b, n, c = feat.shape
        c_ = c // self.num_heads
        assert c == self.dim, "dim does not accord to input"
        h = self.num_heads

        # get qkv
        q = self.q(feat)  # b x n x c
        q = q * self.scale
        kv = self.kv(feat)  # b x n x 2c

        # get attention
        if global_attn:
            q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)  # b x h x n x c_
            kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4)  # 2 x b x h x n x c_
            key, v = kv[0], kv[1]
            attn = q @ key.transpose(-1, -2)  # b x h x n x n
            mask = None
        else:
            nbhd_size = member_idx.shape[-1]
            m = nbhd_size
            q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)
            kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4)  # 2 x b x h x n x c_
            key, v = kv[0], kv[1]
            attn = CLUSTENQKFunction.apply(q, key, member_idx)  # b x h x n x m
            mask = cluster_mask
            if mask is not None:
                mask = mask.reshape(b, 1, n, m)

        # position embedding
        global pre_table
        if not pre_table.is_cuda:
            pre_table = pre_table.to(pe_idx.device)
        pe_table = self.pos_embed(pre_table)  # 111 x 111 x h for img_size 224x224

        pe_shape = pe_idx.shape
        pos_embed = pe_table.gather(index=pe_idx.view(-1, 1).expand(-1, h), dim=0).reshape(*(pe_shape), h).permute(0, 3, 1, 2)

        attn = attn + pos_embed

        if mask is not None:
            attn = attn + (1-mask)*(-100)

        # blank token
        blank_attn = (q * self.blank_k.reshape(1, h, 1, c_)).sum(-1, keepdim=True)  # b x h x n x 1
        attn = torch.cat([attn, blank_attn], dim=-1)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        blank_attn = attn[..., -1:]
        attn = attn[..., :-1]
        blank_v = blank_attn * self.blank_v.reshape(1, h, 1, c_)  # b x h x n x c_

        # aggregate v
        if global_attn:
            feat = (attn @ v).permute(0, 2, 1, 3).reshape(b, n, c)
            feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c)
        else:
            feat = CLUSTENAVFunction.apply(attn, v, member_idx).permute(0, 2, 1, 3).reshape(b, n, c)
            feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c)

        feat = self.proj(feat)
        feat = self.proj_drop(feat)

        return feat

    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'

The forward function of ClusterAttention contains this branch:

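if global_attn:
    # global attention: each token attends to all n tokens
    attn = q @ key.transpose(-1, -2)  # b x h x n x n
    mask = None
else:
    # local attention: each token attends only to the tokens listed in member_idx
    attn = CLUSTENQKFunction.apply(q, key, member_idx)  # b x h x n x m
    mask = cluster_mask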

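For local attention, CLUSTENQKFunction computes the attention scores. It takes member_idx, the neighbor indices of each token, and uses it to fetch the corresponding key vectors, so only the b x h x n x m scores between each token and its neighbors are computed. cluster_mask marks padded neighbor slots: before the softmax, (1-mask)*(-100) is added to the scores, so these slots receive almost zero attention weight and the attention stays on valid neighbors. (CLUSTENQKFunction is defined in /clusten/src/clusten.py.)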
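CLUSTENQKFunction and CLUSTENAVFunction are custom kernels. As an illustration of what they compute, not of how they are implemented, the two steps can be written with plain gathers and checked against dense masked attention. local_qk and local_av are hypothetical names; ClusterAttention's query scaling, position embedding and blank token are left out:

```python
import torch

def local_qk(q, key, member_idx):
    """Hypothetical reference: q, key b x h x n x c_, member_idx b x n x m -> b x h x n x m."""
    b, h, n, c_ = key.shape
    m = member_idx.shape[-1]
    idx = member_idx.reshape(b, 1, n * m, 1).expand(-1, h, -1, c_)
    k_nbhd = key.gather(2, idx).reshape(b, h, n, m, c_)   # keys of each token's neighbors
    return (q.unsqueeze(3) * k_nbhd).sum(-1)

def local_av(attn, v, member_idx):
    """Hypothetical reference: attn b x h x n x m, v b x h x n x c_ -> b x h x n x c_."""
    b, h, n, c_ = v.shape
    m = member_idx.shape[-1]
    idx = member_idx.reshape(b, 1, n * m, 1).expand(-1, h, -1, c_)
    v_nbhd = v.gather(2, idx).reshape(b, h, n, m, c_)     # values of each token's neighbors
    return (attn.unsqueeze(-1) * v_nbhd).sum(3)

torch.manual_seed(0)
b, h, n, c_ = 1, 2, 4, 3
q, key, v = (torch.randn(b, h, n, c_) for _ in range(3))
member_idx = torch.tensor([[[0, 1], [1, 2], [2, 3], [3, 0]]])    # b x n x m
cluster_mask = torch.tensor([[[1, 1], [1, 1], [1, 1], [1, 0]]])  # last slot of token 3 is padding

attn = local_qk(q, key, member_idx)
attn = attn + (1 - cluster_mask.reshape(b, 1, n, -1)) * (-100)  # masking rule used in ClusterAttention
out_local = local_av(attn.softmax(-1), v, member_idx)

# dense reference: all n x n scores, non-neighbors set to -inf
allowed = torch.zeros(b, n, n)
allowed.scatter_add_(2, member_idx, cluster_mask.float())
dense = (q @ key.transpose(-1, -2)).masked_fill(allowed.unsqueeze(1) == 0, float('-inf'))
out_dense = dense.softmax(-1) @ v
print(torch.allclose(out_local, out_dense, atol=1e-5))  # True
```

The local version only ever touches n x m scores, whereas the dense check needs n x n; the -100 offset is large enough that its residual weight is negligible next to -inf.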

4. Adaptive sampling

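At the end of BasicLayer, tokens are sampled according to their importance: important tokens are kept and unimportant ones are discarded. This is done by the ClusterMerging layer, which has three parts:
(1) Compute a retention score for each token from a location prior and the token's learned importance score.
(2) Select the tokens to keep according to these scores.
(3) Merge the neighborhood of each kept token to produce a new, sparser feature map.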
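The location prior in step (1) works as follows in ClusterMerging: at stride 2 a token is favored when both coordinates are multiples of 2; in later stages each token's stride is estimated from the distance to its nearest other token, ada_stride = 2**(ceil(log2(min_dist)) + 1), and the same divisibility test is applied. A standalone restatement of those lines, with torch.cdist swapped in for knn_keops; grid_prior is a hypothetical name:

```python
import torch

def grid_prior(pos, stride):
    """Hypothetical restatement of ClusterMerging's grid_prob (b x n) for pos b x n x 2."""
    if stride == 2:
        # first downsampling: tokens are still on the regular grid
        return ((pos % stride).sum(-1) == 0).float()
    dist = torch.cdist(pos, pos)                                    # b x n x n, stand-in for knn_keops
    min_dist = dist.topk(2, dim=-1, largest=False).values[..., 1]  # nearest *other* token
    ada_stride = 2 ** (min_dist.log2().ceil() + 1)                  # local spacing as power of 2, doubled
    return ((pos.long() % ada_stride.unsqueeze(2).long()).sum(-1) == 0).float()

print(grid_prior(torch.tensor([[[0., 0.], [1., 0.], [2., 0.], [2., 2.]]]), 2).tolist())
# [[1.0, 0.0, 1.0, 1.0]]
print(grid_prior(torch.tensor([[[0., 0.], [4., 0.], [8., 0.], [12., 0.]]]), 4).tolist())
# [[1.0, 0.0, 1.0, 0.0]]
```

In the second call the tokens are 4 apart, so ada_stride is 8 and every other token along the row gets the prior.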

class ClusterMerging(nn.Module):
    r""" Adaptive Downsampling.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0
        ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25
        reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True
    """

    def __init__(self, dim, out_dim, norm_layer=nn.LayerNorm, alpha=4.0, ds_rate=0.25, reserve_on=True):
        super().__init__()
        self.dim = dim
        self.pos_dim = 2
        self.alpha = alpha
        self.ds_rate = ds_rate
        self.reserve_on = reserve_on

        # pointconv
        inner_ch = 4
        self.weight_net = nn.Sequential(
            nn.Linear(self.pos_dim+3, inner_ch, bias=True),
            nn.LayerNorm(inner_ch),
            nn.GELU()
        )
        self.norm = norm_layer(inner_ch*dim)
        self.linear = nn.Linear(dim*inner_ch, out_dim)

    def forward(self, pos, feat, member_idx, cluster_mask, learned_prob, stride, pe_idx, reserve_num):
        """
        Args:
            pos - b x n x 2, token positions
            feat - b x n x c, token features
            member_idx - b x n x nbhd, token idx in each local nbhd
            cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
            learned_prob - b x n x 1, learned importance scores
            stride - int, "stride" of the current feature map, 2,4,8 for the 3 stages respectively
            pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
            reserve_num - int, number of tokens to be reserved
        """

        b, n, c = feat.shape
        d = pos.shape[2]

        keep_num = int(n*self.ds_rate)

        # grid prior
        if stride == 2:  # no ada ds yet, no need ada grid
            grid_prob = ((pos % stride).sum(-1) == 0).float()  # b x n
        else:
            _, min_dist = knn_keops(pos, pos, 2, return_dist=True)  # b x n x 2
            min_dist = min_dist[:, :, 1]  # b x n
            ada_stride = 2**(min_dist.log2().ceil()+1)  # b x n
            grid_prob = ((pos.long() % ada_stride.unsqueeze(2).long()).sum(-1) == 0).float()  # b x n

        final_prob = grid_prob

        # add importance score
        if learned_prob is not None:
            lp = learned_prob.detach().view(b, n)
            lp = lp * self.alpha
            final_prob = final_prob + lp

        # reserve points on a coarse grid
        if self.reserve_on:
            reserve_mask = ((pos % (stride*2)).sum(-1) == 0).float()  # b x n
            final_prob = final_prob + (reserve_mask*(-100))
            sample_num = keep_num - reserve_num
        else:
            sample_num = keep_num

        # select topk tokens as merging centers
        sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1]  # b x n_

        if self.reserve_on:
            reserve_idx = reserve_mask.nonzero(as_tuple=True)[1].reshape(b, reserve_num)
            idx = torch.cat([sample_idx, reserve_idx], dim=-1).unsqueeze(2)  # b x n_ x 1
        else:
            idx = sample_idx.unsqueeze(2)

        n = idx.shape[1]
        assert n == keep_num, "n not equal to keep num!"

        # gather pos, nbhd, nbhd position embedding, nbhd importance scores for topk merging locations
        pos = pos.gather(index=idx.expand(-1, -1, d), dim=1)  # b x n' x d

        nbhd_size = member_idx.shape[-1]
        member_idx = member_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        pe_idx = pe_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        if cluster_mask is not None:
            cluster_mask = cluster_mask.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        if learned_prob is not None:
            lp = learned_prob.gather(index=member_idx.view(b, -1, 1), dim=1).reshape(b, n, nbhd_size, 1)  # b x n x m x 1

        # pointconv weights
        global pre_table
        if not pre_table.is_cuda:
            pre_table = pre_table.to(pe_idx.device)
        weights_table = self.weight_net(pre_table)  # 111 x 111 x ic

        weight_shape = pe_idx.shape
        inner_ch = weights_table.shape[-1]
        weights = weights_table.gather(index=pe_idx.view(-1, 1).expand(-1, inner_ch), dim=0).reshape(*(weight_shape), inner_ch)

        if learned_prob is not None:
            if cluster_mask is not None:
                lp = lp * cluster_mask.unsqueeze(3)
            weights = weights * lp
        else:
            if cluster_mask is not None:
                weights = weights * cluster_mask.unsqueeze(3)

        # merge features
        feat = CLUSTENWFFunction.apply(weights, feat, member_idx.view(b, n, -1)).reshape(b, n, -1)  # b x n x ic*c
        feat = self.norm(feat)
        feat = self.linear(feat)  # b x n x 2c

        return pos, feat

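In the forward function of ClusterMerging, a grid_prob is computed first: a retention prior based on token positions. BasicLayer also passes in learned_prob, the importance score of each token (the sigmoid output of prob_net). Adding alpha * learned_prob to grid_prob gives each token's final retention score, final_prob. topk on final_prob then selects the tokens to keep, and their indices are stored in sample_idx: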

sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1]
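The selection logic, including the reserved coarse-grid tokens, can be restated on toy scores. select_centers is hypothetical, and counting reserve_num from the mask is my simplification; in the repository it is computed in BasicLayer as ceil(h/(stride*2)) * ceil(w/(stride*2)):

```python
import torch

def select_centers(grid_prob, learned_prob, reserve_mask, keep_num, alpha=4.0):
    """Hypothetical restatement of ClusterMerging's token selection (b x n inputs)."""
    b = grid_prob.shape[0]
    reserve_num = int(reserve_mask[0].sum())   # assumption: same count in every image
    # prior + weighted importance; reserved tokens get -100 so topk never picks them twice
    final_prob = grid_prob + alpha * learned_prob - 100 * reserve_mask
    sample_idx = final_prob.topk(keep_num - reserve_num, dim=1, sorted=False)[1]
    reserve_idx = reserve_mask.nonzero(as_tuple=True)[1].reshape(b, reserve_num)
    return torch.cat([sample_idx, reserve_idx], dim=-1)   # b x keep_num

grid_prob = torch.tensor([[1., 0., 1., 0., 1., 0.]])
learned_prob = torch.tensor([[0.1, 0.9, 0.2, 0.05, 0.3, 0.0]])
reserve_mask = torch.tensor([[1., 0., 0., 0., 0., 0.]])   # token 0 lies on the coarse grid
idx = select_centers(grid_prob, learned_prob, reserve_mask, keep_num=3)
print(sorted(idx[0].tolist()))  # [0, 1, 4]
```

Token 1 has no grid prior but is kept because of its high importance score, token 2 loses to token 4 on importance, and token 0 is kept through the reserve path rather than through topk.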

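To keep a certain proportion of anchor tokens at every scale, a reserve_mask is also folded into the sampling: tokens on a coarser grid (coordinates divisible by stride*2) get -100 added to their score so that topk skips them, and they are then always appended as reserved tokens. Using the combined indices, the positions, neighbor indices, neighbor position-embedding indices and neighbor importance scores of the kept tokens are gathered. Each kept token's neighbors are then merged into a new feature by a weighted sum, where the weights come from a small PointConv-style network (weight_net applied to the position lookup table) and are scaled by the neighbors' importance scores. CLUSTENWFFunction implements this indexed weighted merge efficiently. (CLUSTENWFFunction is defined in /clusten/src/clusten.py.)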
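CLUSTENWFFunction is also a custom kernel. Judging only from the shapes in ClusterMerging (weights b x n' x m x ic, feat b x n x c, output reshaped to b x n' x ic*c), it computes, for each kept token, ic weighted sums over its neighbors' features. A pure-PyTorch sketch under that reading; weighted_merge is hypothetical, and the ic-major flattening order is an assumption about the real kernel:

```python
import torch

def weighted_merge(weights, feat, member_idx):
    """Hypothetical reference: weights b x n' x m x ic, feat b x n x c,
    member_idx b x n' x m -> b x n' x ic*c."""
    b, n_new, m, ic = weights.shape
    c = feat.shape[-1]
    # features of every neighbor of every kept token
    nbr_feat = feat.gather(1, member_idx.reshape(b, -1, 1).expand(-1, -1, c)).reshape(b, n_new, m, c)
    merged = torch.einsum('bnmi,bnmc->bnic', weights, nbr_feat)   # sum over the m neighbors
    return merged.reshape(b, n_new, ic * c)

feat = torch.tensor([[[1., 0.], [0., 1.], [2., 2.]]])    # b=1, n=3 tokens, c=2
member_idx = torch.tensor([[[0, 1], [2, 0]]])            # n'=2 kept tokens, m=2 neighbors each
weights = torch.tensor([[[[1.], [1.]], [[0.5], [0.]]]])  # ic=1; a zero weight acts like a masked slot
print(weighted_merge(weights, feat, member_idx).tolist())  # [[[1.0, 1.0], [1.0, 1.0]]]
```

In ClusterMerging the weights are multiplied by cluster_mask before this step, so padded neighbors contribute nothing, just like the zero weight above.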
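In summary, ClusterMerging produces progressively sparser feature maps in three steps (computing retention scores, sampling the kept tokens, and merging their neighborhoods), which is how adaptive sampling is realized.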

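(Notes: 1. The decoder head is not covered here.
2. On reproduction: the goal here was mainly to follow the authors' ideas and extract the key parts in order to improve local attention; for AutoFocusFormer itself, see the official README.md.)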

Author: 小马敲马

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值