DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

Abstract

Attention is sparse in vision transformers. We observe the final prediction in vision transformers is only based on a subset of most informative tokens, which is sufficient for accurate image recognition. Based on this observation, we propose a dynamic token sparsification framework to prune redundant tokens progressively and dynamically based on the input. Specifically, we devise a lightweight prediction module to estimate the importance score of each token given the current features. The module is added to different layers to prune redundant tokens hierarchically. To optimize the prediction module in an end-to-end manner, we propose an attention masking strategy to differentiably prune a token by blocking its interactions with other tokens. Benefiting from the nature of self-attention, the unstructured sparse tokens are still hardware friendly, which makes our framework easy to achieve actual speed-up. By hierarchically pruning 66% of the input tokens, our method greatly reduces 31% ∼ 37% FLOPs and improves the throughput by over 40% while the drop of accuracy is within 0.5% for various vision transformers. Equipped with the dynamic token sparsification framework, DynamicViT models can achieve very competitive complexity/accuracy trade-offs compared to state-of-the-art CNNs and vision transformers on ImageNet.

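The paper proposes a dynamic token sparsification framework for transformer models that prunes the input token sequence dynamically. The authors design a lightweight prediction network that scores the importance of each token given the current features, and the token sequence of every input is pruned by the same ratio. Applied to the DeiT and LV-ViT architectures, the framework greatly reduces computation at the cost of a slight drop in accuracy.

[Figure: network architecture]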


Source code

Code is available at https://github.com/raoyongming/DynamicViT.


Paper walkthrough

The code shown below is from https://github.com/raoyongming/DynamicViT/blob/master/models/dyvit.py.

Hierarchical Token Sparsification with Prediction Modules

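In this architecture, the authors introduce a binary decision mask made of 0s and 1s that decides whether each token is dropped or kept.

[Equation: decision mask]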

# initialize decision mask D (all ones: every token is kept at the start)
prev_decision = torch.ones(B, init_n, 1, dtype=x.dtype, device=x.device)  # (batch_num, num_patches, 1)

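The prediction module first takes the current tokens and the latest decision mask D as input and computes a local feature and a global feature.

[Equation: local feature]

[Equation: global feature]
where
[Equation: aggregation function]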

The local feature encodes the information of a certain token while the global feature contains the context of the whole image, thus both of them are informative.

Combining the local feature and the global feature gives the local-global embeddings, from which the probability of dropping or keeping each token is predicted.
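Going by the formula comments in the reference code (`Zlocal = MLP(x)`, `Zglobal = Agg(MLP(x), D)`, `π = Softmax(MLP(z))`), the steps can be written out roughly as follows. The notation is an assumption on my part: \hat{D} is the current decision mask, N the number of patch tokens, and Agg is the masked average that `PredictorLG.forward` computes.

```latex
\begin{aligned}
z^{\text{local}} &= \mathrm{MLP}(x) \\
z^{\text{global}} &= \mathrm{Agg}\big(\mathrm{MLP}(x), \hat{D}\big), \qquad
\mathrm{Agg}(u, \hat{D}) = \frac{\sum_{i=1}^{N} \hat{D}_i\, u_i}{\sum_{i=1}^{N} \hat{D}_i} \\
z_i &= \big[\, z_i^{\text{local}},\; z^{\text{global}} \,\big] \\
\pi &= \mathrm{Softmax}\big(\mathrm{MLP}(z)\big) \in \mathbb{R}^{N \times 2}
\end{aligned}
```

In the code, the local and global parts are the two halves of the same MLP output, and π is returned as log-probabilities.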

[Equation: decision computation]

import torch
import torch.nn as nn


class PredictorLG(nn.Module):
    def __init__(self, embed_dim=384):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            # Zi = [Zi(local), Zi(global)]
            nn.Linear(embed_dim // 4, 2),  # (batch_num, num_patches ,2)
            # Gumbel-max trick: x = argmax(log(p) + G), where G is Gumbel-distributed noise
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x, policy):  # policy: (batch_num, num_patches, 1)
        x = self.in_conv(x)  #(batch_num, num_patches, embed_dim)
        B, N, C = x.size()
        # Zlocal = MLP(x)
        local_x = x[:,:, :C//2]  #(batch_num, num_patches, embed_dim / 2)
        # Zglobal = Agg(MLP(x), D)
        global_x = (x[:,:, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)   #(batch_num, 1, embed_dim / 2)
        # Zi = [Zi(local), Zi(global)]
        x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1)   # (batch_num, num_patches, embed_dim)
        # π = Softmax(MLP(z))
        return self.out_conv(x)   # (batch_num, num_patches ,2)
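To make the masked average behind `global_x` concrete, here is a minimal dependency-free sketch in plain Python. The helper name `masked_mean` and the toy numbers are mine, not from the repository; the point is only that a token whose mask entry is 0 contributes nothing to the global feature.

```python
def masked_mean(features, keep):
    """Average the feature vectors of tokens whose keep flag is 1."""
    dim = len(features[0])
    kept = sum(keep)
    return [
        sum(f[d] * k for f, k in zip(features, keep)) / kept
        for d in range(dim)
    ]


# Three tokens with 2-d features; the third token was dropped earlier.
features = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
keep = [1, 1, 0]
global_feature = masked_mean(features, keep)
print(global_feature)
```

The large values of the dropped token never reach the average, which is what dividing by `torch.sum(policy, dim=1, keepdim=True)` achieves in the tensor version.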

Finally, the decision mask is updated with a Hadamard product.

[Equation: decision mask update]

# use Hadamard product to update decision mask D
hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision  # (batch_num, num_patches, 1)
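A small plain-Python sketch of why the element-wise product matters: once a token has been dropped, a later stage cannot bring it back. The function name `update_mask` and the example masks are illustrative only.

```python
def update_mask(prev_decision, new_decision):
    """Element-wise (Hadamard) product of two 0/1 masks."""
    return [p * n for p, n in zip(prev_decision, new_decision)]


prev_decision = [1, 1, 0, 1]  # token 2 was dropped at an earlier stage
new_decision = [1, 0, 1, 1]   # this stage drops token 1 and would keep token 2
updated = update_mask(prev_decision, new_decision)
print(updated)
```

Token 2 stays dropped even though the new prediction says "keep", so the set of kept tokens can only shrink from stage to stage.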

End-to-end Optimization with Attention Masking

This prediction module can be optimized jointly in an end-to-end manner together with the vision transformer backbone. To this end, two specialized strategies are adopted. The first one is to adopt Gumbel-Softmax to overcome the non-differentiable problem of sampling from a distribution so that it is possible to perform the end-to-end training. The second one is about how to apply this learned binary decision mask to prune the unnecessary tokens.

Sampling from π to obtain the decision mask D is not differentiable. To allow end-to-end training, the authors sample with Gumbel-Softmax.

[Equation: Gumbel-Softmax]
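As a rough illustration of what the forward pass of `F.gumbel_softmax(..., hard=True)` does for a single token, here is a plain-Python sketch. The function name, the clamping constant, and the toy probabilities are assumptions of mine. The straight-through gradient that PyTorch adds on top is not shown, since plain Python has no autograd.

```python
import math
import random


def gumbel_softmax_sample(log_probs, tau=1.0, rng=random):
    """Return (soft, hard) samples for one token's [keep, drop] log-probabilities."""
    noisy = []
    for lp in log_probs:
        # Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # avoid log(0)
        noisy.append((lp - math.log(-math.log(u))) / tau)
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    best = soft.index(max(soft))
    hard = [1.0 if i == best else 0.0 for i in range(len(soft))]
    return soft, hard


rng = random.Random(0)
log_probs = [math.log(0.8), math.log(0.2)]  # keep with probability 0.8
samples = [gumbel_softmax_sample(log_probs, rng=rng)[1] for _ in range(2000)]
keep_rate = sum(s[0] for s in samples) / len(samples)
print(round(keep_rate, 2))
```

Each hard sample is one-hot, and over many draws the fraction of "keep" decisions should land close to the predicted keep probability.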

# x[:, 0]=cls_tokens
spatial_x = x[:, 1:]  # (batch_num, patch_num, embed_dim)
pred_score = self.score_predictor[p_count](spatial_x, prev_decision).reshape(B, -1, 2)  # (batch_num, num_patches, 2)
if self.training:
    # use Hadamard product to update decision mask D
    hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision  # (batch_num, num_patches, 1)
    # store decision
    out_pred_prob.append(hard_keep_decision.reshape(B, init_n)) # (batch_num, num_patches)
    cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
    policy = torch.cat([cls_policy, hard_keep_decision], dim=1)  # (batch_num, num_patches+1 ,1)
    x = blk(x, policy=policy)  # (batch_num, num_patches+1 , embed_dim)
    prev_decision = hard_keep_decision  # (batch_num, num_patches, 1)
else:
    score = pred_score[:,:,0]  # log-probability of "keep": (batch_num, num_patches)
    num_keep_node = int(init_n * self.token_ratio[p_count])
    keep_policy = torch.argsort(score, dim=1, descending=True)[:, :num_keep_node]  # (batch_num, num_keep_node)
    cls_policy = torch.zeros(B, 1, dtype=keep_policy.dtype, device=keep_policy.device)  # (batch_num, 1)
    now_policy = torch.cat([cls_policy, keep_policy + 1], dim=1)  # (batch_num, num_keep_node+1)
    x = batch_index_select(x, now_policy)  # (batch_num, num_keep_node+1, embed_dim)
    prev_decision = batch_index_select(prev_decision, keep_policy)  # (batch_num, num_keep_node, 1)
    x = blk(x)  # (batch_num, num_keep_node+1, embed_dim)
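The inference branch boils down to "sort by keep score, take the top k, shift the indices by one for the class token, gather". A minimal plain-Python sketch for a single sample is below. The helper `select_kept_tokens` is hypothetical and only mimics what `torch.argsort` plus the repository's `batch_index_select` appear to do here.

```python
def select_kept_tokens(tokens, keep_scores, num_keep):
    """tokens[0] is the class token; keep_scores has one entry per patch token."""
    order = sorted(range(len(keep_scores)), key=lambda i: keep_scores[i], reverse=True)
    keep_idx = order[:num_keep]                   # indices into the patch tokens
    token_idx = [0] + [i + 1 for i in keep_idx]   # shift by one for the class token
    return [tokens[i] for i in token_idx], keep_idx


tokens = ["cls", "p0", "p1", "p2", "p3"]
keep_scores = [-0.1, -2.3, -0.7, -1.5]  # log-probabilities of "keep"
kept_tokens, keep_idx = select_kept_tokens(tokens, keep_scores, num_keep=2)
print(kept_tokens, keep_idx)
```

Unlike the training branch, the sequence really gets shorter here, which is where the actual speed-up at inference comes from.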

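During training, however, directly discarding the tokens with D_i = 0 would leave the samples in a batch with different numbers of tokens. Simply setting a dropped token to 0 does not remove its influence in self-attention either.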
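A quick plain-Python check of the second point, under the simplifying assumption that a zeroed token produces an attention logit of 0 (that is, ignoring any bias in the key projection). The numbers are made up for illustration.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Attention logits of one query over three keys; key 2 belongs to a
# "dropped" token whose features were set to zero, so its logit is 0.
logits = [1.0, 0.5, 0.0]
weights = softmax(logits)
print([round(w, 3) for w in weights])
```

Because exp(0) = 1, the zeroed token still receives a noticeable share of the attention, so the mask has to be applied inside the softmax rather than to the token values.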

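[Equation: self-attention]
The authors therefore use a neat strategy that removes the influence of dropped tokens while keeping the number of tokens per sample unchanged.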
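Reading the formulas back from `softmax_with_policy` in the reference code, the masking can be written roughly as follows. The symbols are assumptions: P is the scaled dot-product logit matrix, C the per-head dimension, \hat{D} the current decision mask, and \tilde{A} the masked attention matrix.

```latex
\begin{aligned}
P &= \frac{Q K^{\top}}{\sqrt{C}} \\
G_{ij} &= \begin{cases} 1, & i = j \\ \hat{D}_j, & i \neq j \end{cases} \\
\tilde{A}_{ij} &= \frac{\exp(P_{ij})\, G_{ij}}{\sum_{k=1}^{N} \exp(P_{ik})\, G_{ik}}
\end{aligned}
```

The diagonal of G is always 1, so every token can still attend to itself, which keeps the denominator away from zero even for dropped tokens.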


def softmax_with_policy(self, attn, policy, eps=1e-6):
    # policy is the keep mask including the class token
    B, N, _ = policy.size()  # (batch_num, num_patches+1, 1)
    # attn = (q @ k.transpose(-2, -1)) * self.scale
    B, H, N, N = attn.size()  # (batch_num, num_heads, num_patches+1, num_patches+1)
    attn_policy = policy.reshape(B, 1, 1, N)  # (batch_num, 1, 1, num_patches+1)
    eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)  # (1, 1, num_patches+1, num_patches+1)

    # build matrix G: 1 on the diagonal, the keep mask elsewhere
    attn_policy = attn_policy + (1.0 - attn_policy) * eye  # (batch_num, 1, num_patches+1, num_patches+1)
    max_att = torch.max(attn, dim=-1, keepdim=True)[0]  # (batch_num, num_heads, num_patches+1, 1)
    # subtract the row max to prevent overflow; softmax(x + c) = softmax(x)
    attn = attn - max_att  # (batch_num, num_heads, num_patches+1, num_patches+1)

    # compute the masked attention matrix A in float32 for stable training
    attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)  # (batch_num, num_heads, num_patches+1, num_patches+1)
    attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps)  # (batch_num, num_heads, num_patches+1, num_patches+1)
    return attn.type_as(max_att)
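To see the effect on actual numbers, here is a plain-Python sketch that mirrors the same arithmetic for one head and one sample, using nested lists instead of tensors. It is a re-implementation for illustration, not the repository code, and the logits are arbitrary.

```python
import math


def softmax_with_policy(attn, policy, eps=1e-6):
    """attn: N x N logits for one head; policy: N keep flags (1.0 = kept)."""
    n = len(policy)
    out = []
    for i, row in enumerate(attn):
        m = max(row)
        # G[i][j] is 1 on the diagonal and policy[j] elsewhere
        masked = [
            math.exp(row[j] - m) * (1.0 if i == j else policy[j])
            for j in range(n)
        ]
        total = sum(masked)
        out.append([(v + eps / n) / (total + eps) for v in masked])
    return out


attn = [[0.2, 1.0, -0.5],
        [0.3, 0.1, 0.8],
        [1.5, 0.0, 0.4]]
policy = [1.0, 1.0, 0.0]  # token 2 is dropped
probs = softmax_with_policy(attn, policy)
for row in probs:
    print([round(p, 4) for p in row])
```

Tokens 0 and 1 give token 2 practically no attention, each row still sums to 1, and token 2 keeps its self-attention entry so its own row stays well defined.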

Training and Inference

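During training, the loss function has several parts.
First, a standard cross-entropy loss measures the gap between the model's prediction and the ground truth.

[Equation: loss function 1]
Second, the authors want to minimize the change that the sparsification module causes to the original model, so they use the original network as a teacher model and introduce two dedicated loss terms.
The first of these measures how much the sparsified token sequence differs from the original sequence.

[Equation: loss function 2]
The second measures the difference in the predictions caused by the sparsification module.

[Equation: loss function 3]
The last loss term measures the gap between the actual pruning ratio and the target pruning ratio.

[Equation: loss function 4]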
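A minimal plain-Python sketch of such a ratio term, assuming it is a mean squared gap between the target ratio and the fraction of tokens actually kept, with the ratio expressed as a keep ratio in the same way `token_ratio` is used in the inference code above. The function name and data layout are illustrative, not taken from the repository.

```python
def ratio_loss(decisions_per_stage, target_ratios):
    """decisions_per_stage[s][b] is the 0/1 keep mask of sample b at stage s."""
    total, count = 0.0, 0
    for masks, rho in zip(decisions_per_stage, target_ratios):
        for mask in masks:
            kept_fraction = sum(mask) / len(mask)
            total += (rho - kept_fraction) ** 2
            count += 1
    return total / count


# One stage, two samples, four tokens each, target keep ratio 0.5.
decisions = [[[1, 1, 0, 0], [1, 1, 1, 0]]]
loss = ratio_loss(decisions, target_ratios=[0.5])
print(loss)
```

The first sample hits the target exactly and contributes nothing; the second keeps 75% of its tokens and is penalized for the 0.25 gap.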
The overall training loss combines the four terms above.

[Equation: total loss]
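Putting the four descriptions into symbols gives roughly the following. This is my reconstruction rather than a copy of the missing images, and the notation is assumed: y and t are the student's prediction and final tokens, y' and t' the teacher's, \bar{y} the label, \hat{D}^{b,s} the mask of sample b after stage s, B the batch size, S the number of pruning stages, and the λ weights are left symbolic.

```latex
\begin{aligned}
\mathcal{L}_{\text{cls}} &= \mathrm{CrossEntropy}(y, \bar{y}) \\
\mathcal{L}_{\text{distill}} &= \frac{1}{\sum_{b=1}^{B}\sum_{i=1}^{N} \hat{D}_i^{b,S}}
  \sum_{b=1}^{B}\sum_{i=1}^{N} \hat{D}_i^{b,S}\, \big(t_i - t'_i\big)^2 \\
\mathcal{L}_{\text{KL}} &= \mathrm{KL}\big(y \,\|\, y'\big) \\
\mathcal{L}_{\text{ratio}} &= \frac{1}{BS} \sum_{b=1}^{B}\sum_{s=1}^{S}
  \Big(\rho^{(s)} - \frac{1}{N}\sum_{i=1}^{N} \hat{D}_i^{b,s}\Big)^2 \\
\mathcal{L} &= \mathcal{L}_{\text{cls}} + \lambda_{\text{KL}}\,\mathcal{L}_{\text{KL}}
  + \lambda_{\text{distill}}\,\mathcal{L}_{\text{distill}} + \lambda_{\text{ratio}}\,\mathcal{L}_{\text{ratio}}
\end{aligned}
```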
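Given a pruning ratio ρ, the tokens can be ranked with

[Equation: pruning order]
so that the number of informative tokens kept at each pruning stage is

[Equation: number of kept tokens]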
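As a worked example of the keep counts, the sketch below follows the `int(init_n * self.token_ratio[p_count])` line from the inference code. The geometric schedule [ρ, ρ², ρ³] with ρ = 0.7 and the 196-token input (14 x 14 patches) are assumptions I chose for illustration, and ρ is treated as a keep ratio.

```python
def keep_counts(num_tokens, token_ratios):
    """Number of patch tokens kept after each pruning stage."""
    return [int(num_tokens * r) for r in token_ratios]


rho = 0.7
ratios = [rho, rho ** 2, rho ** 3]  # assumed geometric schedule
counts = keep_counts(196, ratios)   # 196 = 14 x 14 patch tokens
print(counts)
```

With these numbers the last stage keeps 67 of 196 tokens, which means about 66% of the input tokens are pruned, in line with the figure quoted in the abstract.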
