Abstract
Attention is sparse in vision transformers. We observe the final prediction in vision transformers is only based on a subset of most informative tokens, which is sufficient for accurate image recognition. Based on this observation, we propose a dynamic token sparsification framework to prune redundant tokens progressively and dynamically based on the input. Specifically, we devise a lightweight prediction module to estimate the importance score of each token given the current features. The module is added to different layers to prune redundant tokens hierarchically. To optimize the prediction module in an end-to-end manner, we propose an attention masking strategy to differentiably prune a token by blocking its interactions with other tokens. Benefiting from the nature of self-attention, the unstructured sparse tokens are still hardware friendly, which makes our framework easy to achieve actual speed-up. By hierarchically pruning 66% of the input tokens, our method greatly reduces 31% ∼ 37% FLOPs and improves the throughput by over 40% while the drop of accuracy is within 0.5% for various vision transformers. Equipped with the dynamic token sparsification framework, DynamicViT models can achieve very competitive complexity/accuracy trade-offs compared to state-of-the-art CNNs and vision transformers on ImageNet.
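The paper proposes a dynamic token sparsification framework for transformer models that dynamically prunes the input token sequence. The authors design a lightweight prediction network that estimates the importance of each token given the current features, and the token sequence is then pruned at a fixed ratio. Applied to the DeiT and LV-ViT architectures, the framework greatly reduces computation at the cost of a slight drop in accuracy.
Source code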
Code is available at https://github.com/raoyongming/DynamicViT.
Paper walkthrough
The code shown below is from https://github.com/raoyongming/DynamicViT/blob/master/models/dyvit.py.
Hierarchical Token Sparsification with Prediction Modules
In this architecture, the authors introduce a binary decision mask, with entries 0 or 1, that decides whether each token is dropped or kept.
# initialize decision matrix D
prev_decision = torch.ones(B, init_n, 1, dtype=x.dtype, device=x.device) # (batch_num, num_patches, 1)
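First, the current tokens x and the latest decision matrix D are taken as input to compute a local feature and a global feature:
$$z^{\text{local}} = \mathrm{MLP}(x), \qquad z^{\text{global}} = \mathrm{Agg}(\mathrm{MLP}(x), D)$$
where Agg averages over the tokens that are still kept:
$$\mathrm{Agg}(u, D) = \frac{\sum_{i=1}^{N} D_i \, u_i}{\sum_{i=1}^{N} D_i}$$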
The local feature encodes the information of a certain token while the global feature contains the context of the whole image, thus both of them are informative.
Combining the local and global features gives local-global embeddings, from which the probability of dropping or keeping each token is predicted.
import torch
import torch.nn as nn

class PredictorLG(nn.Module):
    def __init__(self, embed_dim=384):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )
        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            # two outputs per token: index 0 = keep, index 1 = drop
            nn.Linear(embed_dim // 4, 2), # (batch_num, num_patches, 2)
            # Gumbel-max sampling: x = argmax(log(p) + G), where G is Gumbel-distributed noise,
            # so the module outputs log-probabilities
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x, policy): # policy: (batch_num, num_patches, 1)
        x = self.in_conv(x) # (batch_num, num_patches, embed_dim)
        B, N, C = x.size()
        # Zlocal = MLP(x)
        local_x = x[:, :, :C//2] # (batch_num, num_patches, embed_dim / 2)
        # Zglobal = Agg(MLP(x), D)
        global_x = (x[:, :, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True) # (batch_num, 1, embed_dim / 2)
        # Zi = [Zi(local), Zi(global)]
        x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1) # (batch_num, num_patches, embed_dim)
        # π = Softmax(MLP(z))
        return self.out_conv(x) # (batch_num, num_patches, 2)
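The global feature in `PredictorLG.forward` is a masked average: only tokens whose decision is still 1 contribute. The sketch below isolates that step on a toy tensor; the helper name `masked_mean` and the sample values are illustrative assumptions, not part of the repository.

```python
import torch

def masked_mean(features: torch.Tensor, decision: torch.Tensor) -> torch.Tensor:
    """Average features over the tokens whose decision is 1.

    features: (batch, num_tokens, channels)
    decision: (batch, num_tokens, 1), entries in {0, 1}
    returns:  (batch, 1, channels)
    """
    kept_sum = (features * decision).sum(dim=1, keepdim=True)
    kept_count = decision.sum(dim=1, keepdim=True)
    return kept_sum / kept_count

# One image, three tokens, two channels; the last token has already been dropped.
features = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
decision = torch.tensor([[[1.0], [1.0], [0.0]]])
global_feature = masked_mean(features, decision)  # mean of the first two tokens only
```

The dropped token's large values do not leak into the result. Note that this sketch, like the snippet it mirrors, divides by the number of kept tokens, so it assumes at least one token per image is still kept.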
Finally, the decision matrix is updated with a Hadamard product, so a token that was dropped at an earlier stage stays dropped.
# use Hadamard product to update decision matrix D
hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision # (batch_num, num_patches, 1)
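To see why the element-wise product makes pruning monotonic, here is a minimal sketch on random scores. The tensor sizes and the random `pred_score` are assumptions that stand in for the predictor output; only `F.gumbel_softmax` and the product itself mirror the line above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N = 2, 6
prev_decision = torch.ones(B, N, 1)
prev_decision[:, -2:] = 0.0  # pretend the last two tokens were dropped earlier

# Hypothetical log-probabilities (keep, drop) standing in for the predictor output.
pred_score = torch.log_softmax(torch.randn(B, N, 2), dim=-1)

# hard=True returns one-hot samples in the forward pass (up to float rounding)
# while gradients flow through the soft sample; channel 0 is the "keep" indicator.
sampled_keep = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1]
new_decision = sampled_keep * prev_decision  # Hadamard (element-wise) product
```

Whatever the sampler proposes for the last two tokens, multiplying by the previous mask forces them back to 0, so the set of kept tokens can only shrink from stage to stage.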
End-to-end Optimization with Attention Masking
This prediction module can be optimized jointly in an end-to-end manner together with the vision transformer backbone. To this end, two specialized strategies are adopted. The first one is to adopt Gumbel-Softmax to overcome the non-differentiable problem of sampling from a distribution so that it is possible to perform the end-to-end training. The second one is about how to apply this learned binary decision mask to prune the unnecessary tokens.
When sampling from π to obtain the decision matrix D, the authors use Gumbel-Softmax so that the sampling step is differentiable and the model can be trained end to end.
# x[:, 0] is the cls token
spatial_x = x[:, 1:] # (batch_num, num_patches, embed_dim)
pred_score = self.score_predictor[p_count](spatial_x, prev_decision).reshape(B, -1, 2) # (batch_num, num_patches, 2)
if self.training:
    # use Hadamard product to update decision matrix D
    hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision # (batch_num, num_patches, 1)
    # store decision
    out_pred_prob.append(hard_keep_decision.reshape(B, init_n)) # (batch_num, num_patches)
    # the cls token is always kept
    cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
    policy = torch.cat([cls_policy, hard_keep_decision], dim=1) # (batch_num, num_patches+1, 1)
    x = blk(x, policy=policy) # (batch_num, num_patches+1, embed_dim)
    prev_decision = hard_keep_decision # (batch_num, num_patches, 1)
else:
    score = pred_score[:, :, 0] # (batch_num, num_patches), log-probability of keeping each token
    num_keep_node = int(init_n * self.token_ratio[p_count])
    keep_policy = torch.argsort(score, dim=1, descending=True)[:, :num_keep_node] # (batch_num, num_keep_node)
    cls_policy = torch.zeros(B, 1, dtype=keep_policy.dtype, device=keep_policy.device) # (batch_num, 1), index of the cls token
    now_policy = torch.cat([cls_policy, keep_policy + 1], dim=1) # (batch_num, num_keep_node+1)
    x = batch_index_select(x, now_policy) # (batch_num, num_keep_node+1, embed_dim)
    prev_decision = batch_index_select(prev_decision, keep_policy) # (batch_num, num_keep_node, 1)
    x = blk(x) # (batch_num, num_keep_node+1, embed_dim)
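The inference branch physically removes tokens by sorting the keep scores and gathering the top entries per sample. The snippet relies on a `batch_index_select` helper that is not shown here; the sketch below uses `torch.gather` as a stand-in, so `gather_tokens` and the toy scores are assumptions rather than the repository's implementation.

```python
import torch

def gather_tokens(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Pick tokens per sample. x: (B, N, C), idx: (B, K) int64 -> (B, K, C)."""
    C = x.shape[-1]
    return torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, C))

B, N, C, K = 2, 4, 3, 2
# N patch tokens plus one cls token at index 0.
x = torch.arange(B * (N + 1) * C, dtype=torch.float32).reshape(B, N + 1, C)
score = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                      [0.8, 0.2, 0.7, 0.4]])  # keep scores for the N patch tokens

keep_idx = torch.argsort(score, dim=1, descending=True)[:, :K]  # (B, K)
cls_idx = torch.zeros(B, 1, dtype=keep_idx.dtype)                # cls token is always kept
full_idx = torch.cat([cls_idx, keep_idx + 1], dim=1)             # shift past the cls token
x_kept = gather_tokens(x, full_idx)                              # (B, K + 1, C)
```

Because every sample keeps exactly K patch tokens, the result is still a dense tensor, which is what lets the remaining blocks run on a shorter sequence without any special sparse kernels.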
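During training, however, directly discarding the tokens with D_i = 0 would leave the samples in a batch with different numbers of tokens. Simply setting those tokens to zero does not help either, because they would still take part in self-attention and influence the other tokens.
The authors therefore adopt a neat strategy that removes the influence of the dropped tokens while keeping the number of tokens per sample unchanged: the attention weights toward dropped tokens are masked out inside the softmax.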
def softmax_with_policy(self, attn, policy, eps=1e-6):
    # policy = [cls_policy, hard_keep_decision]
    B, N, _ = policy.size() # (batch_num, num_patches+1, 1)
    # attn = (q @ k.transpose(-2, -1)) * self.scale
    B, H, N, N = attn.size() # (batch_num, num_heads, num_patches+1, num_patches+1)
    attn_policy = policy.reshape(B, 1, 1, N) # (batch_num, 1, 1, num_patches+1)
    eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N) # (1, 1, num_patches+1, num_patches+1)
    # build matrix G: a dropped token stays visible only to itself
    attn_policy = attn_policy + (1.0 - attn_policy) * eye # (batch_num, 1, num_patches+1, num_patches+1)
    max_att = torch.max(attn, dim=-1, keepdim=True)[0] # (batch_num, num_heads, num_patches+1, 1)
    # subtract the row maximum to prevent overflow;
    # this leaves the result unchanged because softmax(x + c) = softmax(x)
    attn = attn - max_att # (batch_num, num_heads, num_patches+1, num_patches+1)
    # compute in float32 for stable training, then mask to obtain matrix A
    attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32) # (batch_num, num_heads, num_patches+1, num_patches+1)
    attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps) # (batch_num, num_heads, num_patches+1, num_patches+1)
    return attn.type_as(max_att)
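A small standalone check makes the effect of the mask concrete. The function below is a simplified restatement of `softmax_with_policy` without the `self` argument or dtype handling; the name `masked_softmax` and the toy inputs are assumptions for illustration.

```python
import torch

def masked_softmax(attn: torch.Tensor, policy: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """attn: (B, H, N, N) raw logits; policy: (B, N, 1) with 1 = keep, 0 = drop."""
    B, H, N, _ = attn.shape
    keep = policy.reshape(B, 1, 1, N)                      # mask over key positions
    eye = torch.eye(N, dtype=keep.dtype).view(1, 1, N, N)
    mask = keep + (1.0 - keep) * eye                       # dropped keys visible only to themselves
    attn = attn - attn.max(dim=-1, keepdim=True)[0]        # shift for numerical stability
    attn = attn.exp() * mask
    return (attn + eps / N) / (attn.sum(dim=-1, keepdim=True) + eps)

torch.manual_seed(0)
logits = torch.randn(1, 2, 4, 4)                           # 1 image, 2 heads, 4 tokens
policy = torch.tensor([[[1.0], [1.0], [0.0], [1.0]]])      # token 2 is dropped
probs = masked_softmax(logits, policy)

# Attention that the kept tokens (rows 0, 1, 3) pay to the dropped token (column 2).
leak = probs[0][:, [0, 1, 3], 2]
```

Each row of `probs` still sums to 1, while `leak` is negligible (only the `eps / N` term remains), so the dropped token no longer influences the kept ones even though the tensor shape is unchanged.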
Training and Inference
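During training, the loss function consists of several terms.
First, a standard cross-entropy loss measures the gap between the model's predictions and the ground-truth labels.
Second, the authors want to minimize the change that the sparsification modules cause to the original model, so they use the original network as a teacher model and introduce two additional loss terms.
The first measures how much the sparsified token sequence differs from the original token sequence.
The second measures how much the sparsification modules change the prediction.
The last loss term measures the gap between the actual pruning ratio and the target pruning ratio.
The overall training loss combines these four terms.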
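The four terms can be combined as in the following sketch. This is one plausible reading of the description above, not the authors' code: the function name, the argument layout, the use of a mean squared error for the token term, and the default weights of 1.0 are all assumptions.

```python
import torch
import torch.nn.functional as F

def sparsification_loss(logits, labels, teacher_logits,
                        tokens, teacher_tokens, final_mask,
                        stage_masks, target_ratios,
                        w_kl=1.0, w_distill=1.0, w_ratio=1.0):
    """Hypothetical combination of the four loss terms; weights are placeholders.

    logits, teacher_logits: (B, num_classes); labels: (B,)
    tokens, teacher_tokens: (B, N, C); final_mask: (B, N) with 1 = kept
    stage_masks: list of (B, N) masks, one per pruning stage
    target_ratios: list of target keep ratios, one per pruning stage
    """
    # 1) Classification loss against the ground-truth labels.
    cls_loss = F.cross_entropy(logits, labels)
    # 2) KL divergence between teacher and student predictions.
    kl_loss = F.kl_div(F.log_softmax(logits, dim=-1),
                       F.log_softmax(teacher_logits, dim=-1),
                       reduction="batchmean", log_target=True)
    # 3) Squared error between student and teacher tokens, kept tokens only.
    per_token = (tokens - teacher_tokens).pow(2).mean(dim=-1)  # (B, N)
    distill_loss = (per_token * final_mask).sum() / final_mask.sum().clamp(min=1.0)
    # 4) Squared gap between the realised and the target keep ratio, per stage.
    ratio_loss = sum(((m.mean(dim=1) - r) ** 2).mean()
                     for m, r in zip(stage_masks, target_ratios)) / len(target_ratios)
    return cls_loss + w_kl * kl_loss + w_distill * distill_loss + w_ratio * ratio_loss

torch.manual_seed(0)
B, N, C, num_classes = 2, 8, 4, 5
mask1 = torch.ones(B, N); mask1[:, 6:] = 0   # 75% of the tokens kept
mask2 = mask1.clone(); mask2[:, 4:] = 0      # 50% of the tokens kept
logits = torch.randn(B, num_classes)
labels = torch.tensor([1, 3])
tokens = torch.randn(B, N, C)

loss = sparsification_loss(logits, labels, torch.randn(B, num_classes),
                           tokens, torch.randn(B, N, C), mask2,
                           [mask1, mask2], [0.75, 0.5])
# Sanity case: a student identical to the teacher that hits the target ratios.
loss_same = sparsification_loss(logits, labels, logits, tokens, tokens, mask2,
                                [mask1, mask2], [0.75, 0.5])
```

In the sanity case the three auxiliary terms vanish and only the cross-entropy remains, which is a quick way to check that each term measures what the text says it should.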
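Given a target ratio ρ, the number of tokens kept at each pruning stage is planned in advance; in the inference branch of the code above it is computed as num_keep_node = int(init_n * self.token_ratio[p_count]).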
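As a worked example of such a plan, the sketch below assumes a geometric schedule in which stage s keeps a fraction ρ^s of the initial tokens, over three stages. Both the schedule and the helper `keep_schedule` are assumptions; the text above does not state how the per-stage ratios are chosen.

```python
from typing import List

def keep_schedule(num_tokens: int, base_ratio: float, num_stages: int = 3) -> List[int]:
    """Tokens kept after each stage, assuming stage s keeps base_ratio ** s of the input."""
    return [int(num_tokens * base_ratio ** (s + 1)) for s in range(num_stages)]

# 196 patch tokens corresponds to a 14 x 14 grid, e.g. a 224 x 224 image with 16 x 16 patches.
counts = keep_schedule(196, 0.7)
pruned_fraction = 1 - counts[-1] / 196
print(counts, round(pruned_fraction, 2))  # [137, 96, 67] 0.66
```

Under this assumed schedule with ρ = 0.7, about 66% of the input tokens are gone after the last stage, which agrees with the pruning figure quoted in the abstract.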