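AutoFocusFormer code: https://github.com/apple/ml-autofocusformer
Walkthrough of the paper's key principles: see my previous article.
"The clustering computation is realized through three parts: spatial clustering, local attention, and adaptive sampling. This is the core innovation of AutoFocusFormer."
0. Preface
1) This article singles out a few key pieces of code. Note: all of them can be found in the original repository; none of it is extra code.
2) The three core innovations are meant to be read alongside the paper walkthrough in my previous article.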
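1. Workflow
In the official code, the BasicLayer class in /models/aff_transformer.py lays out how the clustering is computed.
1) First, the tokens are partitioned spatially: the image is divided into multiple clusters, and tokens within the same cluster are treated as being close to each other. This is done by the space_filling_cluster function.
2) For each token, find its nearest few clusters and collect the tokens of those clusters to form its neighborhood. This is done with knn_keops and a few gather operations.
3) During attention, each token attends only to the other tokens in its neighborhood. This gives local attention.
4) Finally, the ClusterMerging layer selects tokens according to their importance, keeping the important ones and discarding the rest. This gives the stage-by-stage downsampling.
2. Spatial clustering logic code
This is mainly implemented in the forward method of BasicLayer. It first calls space_filling_cluster to divide the image into multiple clusters, where tokens within a cluster are treated as being close to each other. Then, for each token, it finds the nearest few clusters and collects the tokens of those clusters to form the neighborhood. This is done with knn_keops and gather operations.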
class BasicLayer(nn.Module):
    """ AutoFocusFormer layer for one stage.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        cluster_size (int): Cluster size.
        nbhd_size (int): Neighbor size. If larger than or equal to number of tokens, perform global attention;
                            otherwise, rounded to the nearest multiples of cluster_size.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0
        ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25
        reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        layer_scale (float, optional): Layer scale initial parameter. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self, dim, out_dim, cluster_size, nbhd_size,
                 depth, num_heads, mlp_ratio,
                 alpha=4.0, ds_rate=0.25, reserve_on=True,
                 drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm,
                 layer_scale=0.0, downsample=None):
        super().__init__()
        self.dim = dim
        self.nbhd_size = nbhd_size
        self.cluster_size = cluster_size
        self.depth = depth

        # build blocks
        self.blocks = nn.ModuleList([
            ClusterTransformerBlock(dim=dim,
                                    num_heads=num_heads,
                                    mlp_ratio=mlp_ratio,
                                    drop=drop, attn_drop=attn_drop,
                                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                    layer_scale=layer_scale,
                                    norm_layer=norm_layer)
            for i in range(depth)])

        # merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, out_dim=out_dim, norm_layer=norm_layer, alpha=alpha, ds_rate=ds_rate, reserve_on=reserve_on)
        else:
            self.downsample = None

        # cache the clustering result for the first feature map since it is on grid
        self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = None, None, None, None, None

        # fc for importance scores
        if downsample is not None:
            self.prob_net = nn.Linear(dim, 1)

    def forward(self, pos, feat, h, w, on_grid, stride):
        """
        Args:
            pos - b x n x 2, token positions
            feat - b x n x c, token features
            h,w - max height and width of token positions
            on_grid - bool, whether the tokens are still on grid; True for the first feature map
            stride - int, "stride" of the current token set; starts with 2, then doubles in each stage
        """
        b, n, d = pos.shape
        c = feat.shape[2]
        assert self.cluster_size > 0, 'self.cluster_size must be positive'

        if self.nbhd_size >= n:
            global_attn = True
            member_idx, cluster_mask = None, None
        else:
            global_attn = False
            k = int(math.ceil(n / float(self.cluster_size)))  # number of clusters
            nnc = min(int(round(self.nbhd_size / float(self.cluster_size))), k)  # number of nearest clusters
            nbhd_size = self.cluster_size * nnc
            self.nbhd_size = nbhd_size  # if not global attention, then nbhd size is rounded to nearest multiples of cluster

        if global_attn:
            rel_pos = (pos[:, None, :, :]+rel_pos_width) - pos[:, :, None, :]  # b x n x n x d
        else:
            if k == n:
                # if number of clusters equal to number of tokens
                cluster_mean_pos = pos
                member_idx = torch.arange(n, device=feat.device).long().reshape(1, n, 1).expand(b, -1, -1)  # b x n x 1
                cluster_mask = None
            else:
                # perform clustering
                if on_grid:
                    if self.cluster_mean_pos is None:
                        self.pos, self.cluster_mean_pos, self.member_idx, self.cluster_mask, self.reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False)
                    pos, cluster_mean_pos, member_idx, cluster_mask = self.pos[:b], self.cluster_mean_pos[:b], self.member_idx[:b], self.cluster_mask
                    # reorder the tokens so that tokens in same cluster are stored together
                    feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), self.reorder[:b].view(-1)].reshape(b, n, c)
                    if cluster_mask is not None:
                        cluster_mask = cluster_mask[:b]
                else:
                    pos, cluster_mean_pos, member_idx, cluster_mask, reorder = space_filling_cluster(pos, self.cluster_size, h, w, no_reorder=False)
                    # reorder the tokens so that tokens in same cluster are stored together
                    feat = feat[torch.arange(b).to(feat.device).repeat_interleave(n), reorder.view(-1)].reshape(b, n, c)

            assert member_idx.shape[1] == k and member_idx.shape[2] == self.cluster_size, "member_idx shape incorrect!"

            nearest_cluster = knn_keops(pos, cluster_mean_pos, nnc)  # b x n x nnc

            # collect neighbor indices from nearest clusters
            m = self.cluster_size
            member_idx = member_idx.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size)  # b x n x nnc*m
            if cluster_mask is not None:
                cluster_mask = cluster_mask.gather(index=nearest_cluster.view(b, -1, 1).expand(-1, -1, m), dim=1).reshape(b, n, nbhd_size)
            pos_ = pos.gather(index=member_idx.view(b, -1, 1).expand(-1, -1, d), dim=1).reshape(b, n, nbhd_size, d)
            rel_pos = pos_ - (pos.unsqueeze(2)-rel_pos_width)  # b x n x nbhd_size x d

        # compute indices in the position embedding lookup table
        pe_idx = (rel_pos[..., 1] * table_width + rel_pos[..., 0]).long()

        for i_blk in range(len(self.blocks)):
            blk = self.blocks[i_blk]
            feat = blk(feat=feat,
                       member_idx=member_idx,
                       cluster_mask=cluster_mask,
                       pe_idx=pe_idx,
                       global_attn=global_attn)

        if self.downsample is not None:
            learned_prob = self.prob_net(feat).sigmoid()  # b x n x 1
            reserve_num = math.ceil(h/(stride*2)) * math.ceil(w/(stride*2))
            pos, feat = self.downsample(pos=pos, feat=feat,
                                        member_idx=member_idx, cluster_mask=cluster_mask,
                                        learned_prob=learned_prob, stride=stride,
                                        pe_idx=pe_idx, reserve_num=reserve_num)

        return pos, feat

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}"
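To make the size bookkeeping at the top of BasicLayer.forward concrete, here is a minimal standalone sketch of how k (number of clusters), nnc (number of nearest clusters) and the rounded nbhd_size are derived. The helper name neighborhood_config and the example numbers (3136 tokens, cluster size 8, neighborhood size 48) are hypothetical and chosen only for illustration; the formulas are the ones in the code above.

```python
import math


def neighborhood_config(n, cluster_size, nbhd_size):
    """Mirror the size bookkeeping at the top of BasicLayer.forward (illustrative helper)."""
    if nbhd_size >= n:
        # The neighborhood would cover every token: fall back to global attention.
        return {"global_attn": True, "k": None, "nnc": None, "nbhd_size": nbhd_size}
    k = int(math.ceil(n / float(cluster_size)))                 # number of clusters
    nnc = min(int(round(nbhd_size / float(cluster_size))), k)   # nearest clusters per token
    # The effective neighborhood size is a whole number of clusters.
    return {"global_attn": False, "k": k, "nnc": nnc, "nbhd_size": cluster_size * nnc}


# Hypothetical numbers, not taken from any official config.
cfg = neighborhood_config(n=3136, cluster_size=8, nbhd_size=48)
print(cfg)
```

Note that a requested neighborhood size that is not a multiple of cluster_size is silently rounded, so for example a request of 50 with clusters of 8 ends up as 48.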
The space_filling_cluster function is defined in point_utils.py:
def space_filling_cluster(pos, m, h, w, no_reorder=False, sf_type='', use_anchor=True):
    """
    The balanced clustering algorithm based on space-filling curves
    In the case where number of tokens not divisible by cluster size,
    the last cluster will have a few blank spots, indicated by the mask returned
    Args:
        pos - b x n x 2, positions of tokens
        m - int, target size of the clusters
        h,w - int, height and width
        no_reorder - bool, if True, return the clustering based on the original order of tokens;
                            otherwise, reorder the tokens so that the same cluster stays together
        sf_type - str, can be 'peano' or 'hilbert', or otherwise, horizontal scanlines w/ alternating
                        direction in each row by default
        use_anchor - bool, whether to use space-filling anchors or not; if False, directly compute
                            space-filling curves on the token positions
    Returns:
        pos - b x n x 2, returned only if no_reorder is False; the reordered position of tokens
        cluster_mean_pos - b x k x 2, the clustering centers
        member_idx - b x k x m, the indices of tokens in each cluster
        cluster_mask - b x k x m, the binary mask indicating the paddings in last cluster (0 if padding)
        pos_ranking - b x n x 1, returned only if no_reorder is False; i-th entry is the idx of the token
                                rank i in the new order
    """
    with torch.no_grad():
        pos = pos.detach()

        if pos.dtype != torch.float:
            pos = pos.to(torch.float)
        b, n, d = pos.shape

        k = int(math.ceil(n/m))

        if use_anchor:
            patch_len = (h*w/k)**0.5
            num_patch_h = int(round(h / patch_len))
            num_patch_w = int(round(w / patch_len))
            patch_len_h, patch_len_w = h / num_patch_h, w / num_patch_w
            if sf_type == 'peano':
                num_patch_h = max(3, int(3**round(math.log(num_patch_h, 3))))
                patch_len_h = h / num_patch_h
                num_patch_w = int(round(w / h * 3) * (num_patch_h / 3))
                patch_len_w = w / num_patch_w
            elif sf_type == 'hilbert':
                num_patch_h = max(2, int(2**round(math.log(num_patch_h, 2))))
                patch_len_h = h / num_patch_h
                num_patch_w = int(round(w / h * 2) * (num_patch_h / 2))
                patch_len_w = w / num_patch_w
            hs = torch.arange(0, num_patch_h, device=pos.device)
            ws = torch.arange(0, num_patch_w, device=pos.device)
            ys, xs = torch.meshgrid(hs, ws)
            grid_pos = torch.stack([xs, ys], dim=2)  # h x w x 2
            grid_pos = grid_pos.reshape(-1, 2)

            # sort the grid centers to one line
            if sf_type == 'peano':
                order_grid_idx, order_idx = calculate_peano_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0))
                order_grid_idx = order_grid_idx[0]
                order_idx = order_idx[0]
            elif sf_type == 'hilbert':
                order_grid_idx, order_idx = calculate_hilbert_order(num_patch_h, num_patch_w, grid_pos.unsqueeze(0))
                order_grid_idx = order_grid_idx[0]
                order_idx = order_idx[0]
            else:
                order_mask = torch.ones_like(ys)  # h x w
                order_mask[1::2] = -1
                order_mask = order_mask * xs
                order_mask = order_mask + ys*w
                order_mask[1::2] += (w-1)
                order_mask = order_mask.reshape(-1)
                order_idx = order_mask.sort()[1]
                order_idx_src = torch.arange(len(order_idx)).to(pos.device)
                order_grid_idx = torch.zeros_like(order_idx_src)
                order_grid_idx.scatter_(index=order_idx, dim=0, src=order_idx_src)

            ordered_grid = grid_pos[order_idx]
            patch_len_hw = torch.Tensor([patch_len_w, patch_len_h]).to(pos.device)

            init_pos_means = ordered_grid * patch_len_hw + patch_len_hw/2 - 0.5
            nump = ordered_grid.shape[0]

            prev_means = torch.zeros_like(init_pos_means)
            prev_means[1:] = init_pos_means[:nump-1].clone()
            prev_means[0] = prev_means[1] - (prev_means[2]-prev_means[1])  # float('inf')
            next_means = torch.zeros_like(init_pos_means)
            next_means[:nump-1] = init_pos_means[1:].clone()
            next_means[-1] = next_means[-2] + (next_means[-2]-next_means[-3])  # float('inf')

            mean_assignment = (pos / patch_len_hw).floor()
            mean_assignment = mean_assignment[..., 0] + mean_assignment[..., 1] * num_patch_w
            mean_assignment = order_grid_idx.unsqueeze(0).expand(b, -1).gather(index=mean_assignment.long(), dim=1).unsqueeze(2)  # b x n x 1

            prev_mean_assign = prev_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1)  # b x n x d
            next_mean_assign = next_means.unsqueeze(0).expand(b, -1, -1).gather(index=mean_assignment.expand(-1, -1, d), dim=1)  # b x n x d
            dist_prev = (pos-prev_mean_assign).pow(2).sum(-1)  # b x n
            dist_next = (pos-next_mean_assign).pow(2).sum(-1)
            dist_ratio = dist_prev / (dist_next + 1e-5)

            pos_ranking = mean_assignment * (dist_ratio.max()+1) + dist_ratio.unsqueeze(2)
            pos_ranking = pos_ranking.sort(dim=1)[1]  # b x n x 1

        else:
            if sf_type == 'peano':
                _, pos_ranking = calculate_peano_order(h, w, pos)
            elif sf_type == 'hilbert':
                _, pos_ranking = calculate_hilbert_order(h, w, pos)
            else:
                hs = torch.arange(0, h, device=pos.device)
                ws = torch.arange(0, w, device=pos.device)
                ys, xs = torch.meshgrid(hs, ws)
                order_mask = torch.ones_like(ys)  # h x w
                order_mask[1::2] = -1
                order_mask = order_mask * xs
                order_mask = order_mask + ys*w
                order_mask[1::2] += (w-1)
                order_mask = order_mask.reshape(-1)
                pos_idx = pos[..., 0] + pos[..., 1] * w
                order_mask = order_mask.gather(index=pos_idx.long().reshape(-1), dim=0).reshape(b, n)
                pos_ranking = order_mask.sort()[1]
            pos_ranking = pos_ranking.unsqueeze(2)

        pos = pos.gather(index=pos_ranking.expand(-1, -1, d), dim=1)  # b x n x d

        if k*m == n:
            cluster_mask = None
            cluster_mean_pos = pos.reshape(b, k, -1, d).mean(2)
        else:
            pos_pad = torch.zeros(b, k*m, d, dtype=pos.dtype, device=pos.device)
            pos_pad[:, :n] = pos.clone()
            cluster_mask = torch.zeros(b, k*m, device=pos.device).long()
            cluster_mask[:, :n] = 1
            cluster_mask = cluster_mask.reshape(b, k, m)
            cluster_mean_pos = pos_pad.reshape(b, k, -1, d).sum(2) / cluster_mask.sum(2, keepdim=True)

        if no_reorder:
            if k*m == n:
                member_idx = pos_ranking.reshape(b, k, m)
            else:
                member_idx = torch.zeros(b, k*m, device=pos.device, dtype=torch.int64)
                member_idx[:, :n] = pos_ranking.squeeze(2)
                member_idx = member_idx.reshape(b, k, m)
            return cluster_mean_pos, member_idx, cluster_mask
        else:
            member_idx = torch.arange(k*m, device=pos.device)
            member_idx[n:] = 0
            member_idx = member_idx.unsqueeze(0).expand(b, -1)  # b x k*m
            member_idx = member_idx.reshape(b, k, m)
            return pos, cluster_mean_pos, member_idx, cluster_mask, pos_ranking
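The core idea of space_filling_cluster can be illustrated with a much smaller sketch: order the tokens along a scanline that alternates direction on each row, then cut the ordered sequence into consecutive chunks of m tokens. The version below is a deliberate simplification and not the repository's implementation: it works on plain Python lists for a single image, skips the anchor logic (it is closest to the use_anchor=False branch with the default sf_type), and the function name scanline_cluster is hypothetical.

```python
import math


def scanline_cluster(positions, m, w):
    """Cluster tokens into groups of m along an alternating-direction scanline.

    positions: list of (x, y) integer coordinates; w: grid width.
    Returns (order, member_idx, cluster_mask), where order[r] is the original
    index of the token placed at rank r after reordering.
    """
    def scan_key(p):
        x, y = p
        # Even rows run left-to-right, odd rows right-to-left.
        return y * w + (x if y % 2 == 0 else w - 1 - x)

    order = sorted(range(len(positions)), key=lambda i: scan_key(positions[i]))
    n = len(order)
    k = math.ceil(n / m)  # number of clusters
    # After reordering, cluster j owns the consecutive slots j*m .. j*m+m-1.
    # Slots past the last token are padding: index 0 with mask 0.
    member_idx = [[j * m + t if j * m + t < n else 0 for t in range(m)] for j in range(k)]
    cluster_mask = [[1 if j * m + t < n else 0 for t in range(m)] for j in range(k)]
    return order, member_idx, cluster_mask


# A 3x3 grid of tokens split into clusters of 4: the last cluster is padded.
grid = [(x, y) for y in range(3) for x in range(3)]
order, member_idx, cluster_mask = scanline_cluster(grid, m=4, w=3)
print(order)
print(member_idx)
print(cluster_mask)
```

Because tokens are reordered so that each cluster is contiguous, member_idx is just a range of consecutive ranks, which matches the no_reorder=False return path of the original function.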
def knn_keops(query, database, k, return_dist=False):
    """
    Compute k-nearest neighbors using the Keops library
    Backward pass turned off; Keops does not provide backward pass for distance
    Args:
        query - b x n_ x c, the position of tokens looking for knn
        database - b x n x c, the candidate tokens for knn
        k - int, the number of neighbors to be found
        return_dist - bool, whether to return distance to the neighbors
    Returns:
        nn_idx - b x n_ x k, the indices of the knn
        nn_dist - b x n_ x k, if return_dist, the distance to the knn
    """
    b, n, c = database.shape
    with torch.no_grad():
        query = query.detach()
        database = database.detach()
        # Keops does not support half precision
        if query.dtype != torch.float32:
            query = query.to(torch.float32)
        if database.dtype != torch.float32:
            database = database.to(torch.float32)
        from pykeops.torch import LazyTensor
        query_ = LazyTensor(query[:, None, :, :])
        database_ = LazyTensor(database[:, :, None, :])
        dist = ((query_-database_) ** 2).sum(-1) ** 0.5  # b x n x n_
    if return_dist:
        nn_dist, nn_idx = dist.Kmin_argKmin(k, dim=1)  # b x n_ x k
        return nn_idx, nn_dist
    else:
        nn_idx = dist.argKmin(k, dim=1)  # b x n_ x k
        return nn_idx
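For readers without KeOps installed, the result of knn_keops can be understood through a brute-force equivalent. The sketch below is a plain-Python reference for a single image, not the library call: the name knn_bruteforce and the example coordinates are hypothetical, and ties are broken by the lower index, which is a choice of this sketch rather than documented KeOps behavior.

```python
def knn_bruteforce(query, database, k):
    """Return, for each query point, the indices of its k nearest database points."""
    result = []
    for q in query:
        # Euclidean distance from this query to every database point.
        dists = [sum((qa - da) ** 2 for qa, da in zip(q, d)) ** 0.5 for d in database]
        # sorted() is stable, so equal distances keep the lower index first.
        ranked = sorted(range(len(database)), key=lambda j: dists[j])
        result.append(ranked[:k])
    return result


# Hypothetical cluster centers and token positions.
cluster_centers = [(1.0, 1.0), (5.0, 1.0), (1.0, 5.0), (5.0, 5.0)]
tokens = [(0.0, 1.0), (4.0, 1.5), (2.0, 4.5)]
nearest = knn_bruteforce(tokens, cluster_centers, k=2)
print(nearest)
```

In BasicLayer.forward the query is the token positions and the database is cluster_mean_pos, so the result plays the role of nearest_cluster (one row of nnc cluster ids per token).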
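gather is a PyTorch tensor operation that collects elements from the input tensor along a given dimension at the positions specified by an index tensor.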
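The gather call that turns nearest_cluster into per-token neighborhoods is easier to follow on a tiny example. The sketch below emulates the dim=1 gather semantics on nested Python lists (out[b][i][j] = src[b][index[b][i][j]][j]) and then replays the view/expand/reshape steps from BasicLayer.forward. The helper gather_dim1 and the toy values are hypothetical and only illustrate the indexing pattern.

```python
def gather_dim1(src, index):
    """Pure-Python analogue of a dim=1 gather on 3-D nested lists.

    out[b][i][j] = src[b][index[b][i][j]][j]
    """
    return [[[src[b][index[b][i][j]][j] for j in range(len(index[b][i]))]
             for i in range(len(index[b]))]
            for b in range(len(index))]


# member_idx: 1 image, k=3 clusters, m=2 member tokens per cluster.
member_idx = [[[0, 1], [2, 3], [4, 5]]]
# nearest_cluster: each of n=2 tokens picks its nnc=2 nearest clusters.
nearest_cluster = [[[1, 0], [2, 1]]]
b, n, nnc, m = 1, 2, 2, 2

# view(b, -1, 1).expand(-1, -1, m): repeat each cluster id across the m member slots.
index = [[[c] * m for row in nearest_cluster[bi] for c in row] for bi in range(b)]
picked = gather_dim1(member_idx, index)  # b x (n*nnc) x m

# reshape(b, n, nnc*m): concatenate the members of each token's nnc clusters.
nbhd = [[[v for c in range(nnc) for v in picked[bi][i * nnc + c]] for i in range(n)]
        for bi in range(b)]
print(nbhd)
```

Token 0 picked clusters 1 and 0, so its neighborhood is the members of cluster 1 followed by those of cluster 0; the same pattern is reused for cluster_mask and for gathering neighbor positions.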
3. Local attention
When attention is computed, each token attends only to the other tokens in its neighborhood. Local attention is achieved by passing member_idx and cluster_mask into ClusterAttention.
class ClusterAttention(nn.Module):
    """
    Performs local attention on nearest clusters

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, num_heads, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.pos_dim = 2
        self.num_heads = num_heads

        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2*dim)
        self.softmax = nn.Softmax(dim=-1)

        self.blank_k = nn.Parameter(torch.randn(dim))
        self.blank_v = nn.Parameter(torch.randn(dim))

        self.pos_embed = nn.Linear(self.pos_dim+3, num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, feat, member_idx, cluster_mask, pe_idx, global_attn):
        """
        Args:
            feat - b x n x c, token features
            member_idx - b x n x nbhd, token idx in each local nbhd
            cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
            pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
            global_attn - bool, whether to perform global attention
        """
        b, n, c = feat.shape
        c_ = c // self.num_heads
        assert c == self.dim, "dim does not accord to input"
        h = self.num_heads

        # get qkv
        q = self.q(feat)  # b x n x c
        q = q * self.scale
        kv = self.kv(feat)  # b x n x 2c

        # get attention
        if global_attn:
            q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)  # b x h x n x c_
            kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4)  # 2 x b x h x n x c_
            key, v = kv[0], kv[1]
            attn = q @ key.transpose(-1, -2)  # b x h x n x n
            mask = None
        else:
            nbhd_size = member_idx.shape[-1]
            m = nbhd_size
            q = q.reshape(b, n, h, -1).permute(0, 2, 1, 3)
            kv = kv.view(b, n, h, 2, c_).permute(3, 0, 2, 1, 4)  # 2 x b x h x n x c_
            key, v = kv[0], kv[1]
            attn = CLUSTENQKFunction.apply(q, key, member_idx)  # b x h x n x m
            mask = cluster_mask
            if mask is not None:
                mask = mask.reshape(b, 1, n, m)

        # position embedding
        global pre_table
        if not pre_table.is_cuda:
            pre_table = pre_table.to(pe_idx.device)
        pe_table = self.pos_embed(pre_table)  # 111 x 111 x h for img_size 224x224

        pe_shape = pe_idx.shape
        pos_embed = pe_table.gather(index=pe_idx.view(-1, 1).expand(-1, h), dim=0).reshape(*(pe_shape), h).permute(0, 3, 1, 2)

        attn = attn + pos_embed

        if mask is not None:
            attn = attn + (1-mask)*(-100)

        # blank token
        blank_attn = (q * self.blank_k.reshape(1, h, 1, c_)).sum(-1, keepdim=True)  # b x h x n x 1
        attn = torch.cat([attn, blank_attn], dim=-1)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        blank_attn = attn[..., -1:]
        attn = attn[..., :-1]
        blank_v = blank_attn * self.blank_v.reshape(1, h, 1, c_)  # b x h x n x c_

        # aggregate v
        if global_attn:
            feat = (attn @ v).permute(0, 2, 1, 3).reshape(b, n, c)
            feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c)
        else:
            feat = CLUSTENAVFunction.apply(attn, v, member_idx).permute(0, 2, 1, 3).reshape(b, n, c)
            feat = feat + blank_v.permute(0, 2, 1, 3).reshape(b, n, c)

        feat = self.proj(feat)
        feat = self.proj_drop(feat)

        return feat

    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'
The forward function of ClusterAttention contains the following branch (abridged):
if global_attn:
    # global attention
    ...
else:
    # local attention
    attn = CLUSTENQKFunction.apply(q, key, member_idx)
    mask = cluster_mask
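For local attention, CLUSTENQKFunction is called to compute the attention logits. It takes member_idx, which holds the neighbor indices of every token, and uses it to pick out the key vectors of those neighbors, so that only local attention is computed. cluster_mask is applied just before the softmax to suppress invalid (padding) neighbors: their logits are shifted by -100, so their attention weights become negligible and the attention stays confined to real neighbors. (CLUSTENQKFunction is defined in /clusten/src/clusten.py.)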
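As a reference for what the local branch computes, here is a minimal single-head, single-image sketch in plain Python. It assumes that CLUSTENQKFunction returns, for each token, the dot products between its query and the keys of its neighbors; the real op is a custom CUDA kernel, and this sketch leaves out the query scaling, the position embedding, the blank token, dropout, and multi-head handling. The function name local_attention_weights is hypothetical.

```python
import math


def local_attention_weights(q, key, member_idx, cluster_mask=None):
    """Neighborhood-restricted attention weights for one head of one image.

    q, key: n x c nested lists; member_idx: n x m neighbor indices;
    cluster_mask: n x m with 1 for valid neighbors and 0 for padding.
    """
    out = []
    for i, nbrs in enumerate(member_idx):
        # Dot product of token i's query with the keys of its neighbors only.
        logits = [sum(a * b for a, b in zip(q[i], key[j])) for j in nbrs]
        if cluster_mask is not None:
            # Same trick as attn + (1 - mask) * (-100): padded slots get a large negative logit.
            logits = [l + (1 - mk) * (-100) for l, mk in zip(logits, cluster_mask[i])]
        # Numerically stable softmax over the neighborhood.
        peak = max(logits)
        exps = [math.exp(l - peak) for l in logits]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
key = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
member_idx = [[0, 1], [2, 0]]      # two neighbors per token
cluster_mask = [[1, 1], [1, 0]]    # token 1's second neighbor slot is padding
weights = local_attention_weights(q, key, member_idx, cluster_mask)
print(weights)
```

The cost per token depends on the neighborhood size m rather than on the total number of tokens n, which is the point of restricting attention to neighborhoods.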
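4. Adaptive sampling
At the end of BasicLayer, tokens are selected according to their importance: important tokens are kept and unimportant ones are discarded. This is done by the ClusterMerging layer, which consists of three parts:
(1) Compute a keep score for each token from a positional (grid) prior and the token's learned importance score.
(2) Select the tokens to keep according to that score.
(3) Merge the neighborhood of each kept token into it, producing a new, sparser feature map.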
class ClusterMerging(nn.Module):
    r""" Adaptive Downsampling.

    Args:
        dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        alpha (float, optional): the weight to be multiplied with importance scores. Default: 4.0
        ds_rate (float, optional): downsampling rate, to be multiplied with the number of tokens. Default: 0.25
        reserve_on (bool, optional): whether to turn on reserve tokens in downsampling. Default: True
    """

    def __init__(self, dim, out_dim, norm_layer=nn.LayerNorm, alpha=4.0, ds_rate=0.25, reserve_on=True):
        super().__init__()
        self.dim = dim
        self.pos_dim = 2
        self.alpha = alpha
        self.ds_rate = ds_rate
        self.reserve_on = reserve_on

        # pointconv
        inner_ch = 4
        self.weight_net = nn.Sequential(
            nn.Linear(self.pos_dim+3, inner_ch, bias=True),
            nn.LayerNorm(inner_ch),
            nn.GELU()
        )
        self.norm = norm_layer(inner_ch*dim)
        self.linear = nn.Linear(dim*inner_ch, out_dim)

    def forward(self, pos, feat, member_idx, cluster_mask, learned_prob, stride, pe_idx, reserve_num):
        """
        Args:
            pos - b x n x 2, token positions
            feat - b x n x c, token features
            member_idx - b x n x nbhd, token idx in each local nbhd
            cluster_mask - b x n x nbhd, binary mask for valid tokens (1 if valid)
            learned_prob - b x n x 1, learned importance scores
            stride - int, "stride" of the current feature map, 2,4,8 for the 3 stages respectively
            pe_idx - b x n x nbhd, idx for the pre-computed position embedding lookup table
            reserve_num - int, number of tokens to be reserved
        """
        b, n, c = feat.shape
        d = pos.shape[2]

        keep_num = int(n*self.ds_rate)

        # grid prior
        if stride == 2:  # no ada ds yet, no need ada grid
            grid_prob = ((pos % stride).sum(-1) == 0).float()  # b x n
        else:
            _, min_dist = knn_keops(pos, pos, 2, return_dist=True)  # b x n x 2
            min_dist = min_dist[:, :, 1]  # b x n
            ada_stride = 2**(min_dist.log2().ceil()+1)  # b x n
            grid_prob = ((pos.long() % ada_stride.unsqueeze(2).long()).sum(-1) == 0).float()  # b x n

        final_prob = grid_prob

        # add importance score
        if learned_prob is not None:
            lp = learned_prob.detach().view(b, n)
            lp = lp * self.alpha
            final_prob = final_prob + lp

        # reserve points on a coarse grid
        if self.reserve_on:
            reserve_mask = ((pos % (stride*2)).sum(-1) == 0).float()  # b x n
            final_prob = final_prob + (reserve_mask*(-100))
            sample_num = keep_num - reserve_num
        else:
            sample_num = keep_num

        # select topk tokens as merging centers
        sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1]  # b x n_

        if self.reserve_on:
            reserve_idx = reserve_mask.nonzero(as_tuple=True)[1].reshape(b, reserve_num)
            idx = torch.cat([sample_idx, reserve_idx], dim=-1).unsqueeze(2)  # b x n_ x 1
        else:
            idx = sample_idx.unsqueeze(2)

        n = idx.shape[1]
        assert n == keep_num, "n not equal to keep num!"

        # gather pos, nbhd, nbhd position embedding, nbhd importance scores for topk merging locations
        pos = pos.gather(index=idx.expand(-1, -1, d), dim=1)  # b x n' x d
        nbhd_size = member_idx.shape[-1]
        member_idx = member_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        pe_idx = pe_idx.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        if cluster_mask is not None:
            cluster_mask = cluster_mask.gather(index=idx.expand(-1, -1, nbhd_size), dim=1)  # b x n' x m
        if learned_prob is not None:
            lp = learned_prob.gather(index=member_idx.view(b, -1, 1), dim=1).reshape(b, n, nbhd_size, 1)  # b x n x m x 1

        # pointconv weights
        global pre_table
        if not pre_table.is_cuda:
            pre_table = pre_table.to(pe_idx.device)
        weights_table = self.weight_net(pre_table)  # 111 x 111 x ic

        weight_shape = pe_idx.shape
        inner_ch = weights_table.shape[-1]
        weights = weights_table.gather(index=pe_idx.view(-1, 1).expand(-1, inner_ch), dim=0).reshape(*(weight_shape), inner_ch)

        if learned_prob is not None:
            if cluster_mask is not None:
                lp = lp * cluster_mask.unsqueeze(3)
            weights = weights * lp
        else:
            if cluster_mask is not None:
                weights = weights * cluster_mask.unsqueeze(3)

        # merge features
        feat = CLUSTENWFFunction.apply(weights, feat, member_idx.view(b, n, -1)).reshape(b, n, -1)  # b x n x ic*c
        feat = self.norm(feat)
        feat = self.linear(feat)  # b x n x 2c

        return pos, feat
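In the forward function of ClusterMerging, a grid_prob is computed first; this is the keep score derived from the positional prior of each token. The enclosing BasicLayer also passes in learned_prob, the per-token importance score (the output of prob_net followed by a sigmoid). The two are combined as grid_prob + alpha * learned_prob to give each token's final score, final_prob. The tokens to keep are then selected with topk on final_prob, and their indices are stored in sample_idx.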
sample_idx = final_prob.topk(sample_num, dim=1, sorted=False)[1]
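The scoring and selection logic can be replayed on a small grid without PyTorch. The sketch below is a simplified, hedged re-implementation for one image that covers only the on-grid case (the stride == 2 branch of the grid prior) with reserve_on=True; the function name select_tokens and the 8x8 example are hypothetical, and the adaptive-stride branch that relies on knn_keops is intentionally left out.

```python
import math


def select_tokens(pos, learned_prob, stride, h, w, alpha=4.0, ds_rate=0.25):
    """Pick which tokens survive downsampling (single image, on-grid, reserve_on=True)."""
    n = len(pos)
    keep_num = int(n * ds_rate)
    reserve_num = math.ceil(h / (stride * 2)) * math.ceil(w / (stride * 2))

    scores, reserved = [], []
    for i, (x, y) in enumerate(pos):
        # Grid prior: 1 for tokens lying on the stride-aligned grid, else 0.
        grid_prob = 1.0 if (x % stride) + (y % stride) == 0 else 0.0
        # Add the learned importance score, weighted by alpha.
        score = grid_prob + alpha * learned_prob[i]
        if (x % (stride * 2)) + (y % (stride * 2)) == 0:
            reserved.append(i)   # coarse-grid token: always kept
            score += -100        # so it is excluded from the top-k competition
        scores.append(score)

    # Top-k over the remaining budget, then append the reserved tokens.
    sample_num = keep_num - reserve_num
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    return ranked[:sample_num] + reserved


pos = [(x, y) for y in range(8) for x in range(8)]   # 8x8 grid, 64 tokens
kept = select_tokens(pos, [0.0] * 64, stride=2, h=8, w=8)
print(len(kept), sorted(pos[i] for i in kept))
```

With all importance scores at zero, the selection falls back to the regular stride-2 grid. Raising the score of an off-grid token lets it displace a grid token, which is how the model focuses tokens on important regions while the reserved coarse grid keeps minimum coverage everywhere.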
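In addition, to guarantee that a set of anchor tokens on a coarse grid survives at every scale, a reserve_mask is built and the reserved tokens are appended to the final selection. Using the resulting indices, the layer gathers the positions of the kept tokens, their neighbor indices, and the position-embedding indices of those neighbors. It then combines each kept token's neighbors with PointConv-style weights (scaled by the neighbors' importance scores) to produce the new feature representation. CLUSTENWFFunction is used here to implement the indexed access and feature merging efficiently. (CLUSTENWFFunction is defined in /clusten/src/clusten.py.)
In summary, ClusterMerging builds progressively sparser feature maps in three steps: computing keep scores, selecting the tokens to keep, and merging neighborhoods. This is what realizes adaptive sampling.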
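The merging step can be sketched in plain Python under one explicit assumption: that CLUSTENWFFunction computes, for every kept token, the sum over its neighbors of the outer product between the neighbor's weight vector (ic channels) and its feature vector (c channels), flattened to ic*c values. This matches the shape comment in the code (b x n x ic*c), but the exact memory layout of the CUDA kernel is not verified here, and the name merge_neighborhood is hypothetical.

```python
def merge_neighborhood(weights, feat, member_idx):
    """Weighted neighborhood merge for one image.

    weights: n' x m x ic, feat: n x c, member_idx: n' x m  ->  n' x (ic*c)
    """
    merged = []
    for w_i, nbrs in zip(weights, member_idx):
        ic, c = len(w_i[0]), len(feat[0])
        acc = [[0.0] * c for _ in range(ic)]
        for w_ij, j in zip(w_i, nbrs):
            # Accumulate the outer product of neighbor j's weights and features.
            for a in range(ic):
                for ch in range(c):
                    acc[a][ch] += w_ij[a] * feat[j][ch]
        # Flatten ic x c row by row (layout is an assumption of this sketch).
        merged.append([v for row in acc for v in row])
    return merged


# One kept token with m=2 neighbors, ic=2 weight channels, c=2 feature channels.
weights = [[[1.0, 0.5], [2.0, 0.0]]]
feat = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
member_idx = [[0, 2]]
merged = merge_neighborhood(weights, feat, member_idx)
print(merged)
```

Since the weights are multiplied by cluster_mask (and by the importance scores when available) before this step, padded neighbors contribute zero to the sum.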
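(Notes: 1. The decoder head is not covered here.
2. On reproduction: this article mainly follows the authors' approach, extracts the key parts, and modifies the local attention; for AutoFocusFormer itself, refer to the official README.md.)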