DynamicViT: Dynamic Token Sparsification for ViT + Gumbel-Softmax

Quick Read

  • Summary: Equipped with the dynamic token sparsification framework, DynamicViT models achieve very competitive complexity/accuracy trade-offs compared with state-of-the-art CNNs and vision transformers on ImageNet.

  • Motivation: The authors observe that the final prediction of a vision transformer is based only on a subset of the most informative tokens, which is sufficient for accurate image recognition.

  • Method: Based on this observation, they propose a dynamic token sparsification framework that prunes redundant tokens progressively and dynamically depending on the input.

  • Method: Concretely, a lightweight prediction module estimates an importance score for each token given the current features. The module is inserted at several layers so that redundant tokens are pruned hierarchically.

  • Method: To optimize the prediction module end to end, an attention-masking strategy prunes tokens differentiably by blocking their interactions with the other tokens (see the sketch after this list).

  • Results: Thanks to the nature of self-attention, the unstructured set of sparse tokens remains hardware friendly, so the framework easily achieves real speed-ups. By hierarchically pruning 66% of the input tokens, the method reduces FLOPs by 31%~37% and improves throughput by over 40%, while the accuracy drop stays within 0.5% across various vision transformers.
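A minimal sketch of that attention-masking idea (an assumed helper for illustration, not the exact repository code): the raw attention scores are re-normalized so that pruned tokens contribute nothing as keys, while every token may still attend to itself so its representation, and the gradient, stay well defined.

import torch

def softmax_with_policy(attn, policy, eps=1e-6):
    # attn: (B, H, N, N) raw attention scores; policy: (B, N, 1), 1 = keep.
    B, N, _ = policy.shape
    key_mask = policy.reshape(B, 1, 1, N)                 # mask tokens as keys
    eye = torch.eye(N, device=attn.device, dtype=attn.dtype).view(1, 1, N, N)
    key_mask = key_mask + (1.0 - key_mask) * eye          # self-attention always allowed
    attn = attn - attn.max(dim=-1, keepdim=True).values   # numerical stability
    attn = attn.exp() * key_mask                          # zero out pruned keys
    return attn / (attn.sum(dim=-1, keepdim=True) + eps)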

1. Convolutional downsampling vs. DynamicViT

CNNs reduce computation through structured downsampling (pooling or strided convolutions over a regular grid), whereas DynamicViT drops an unstructured, input-dependent subset of tokens; because self-attention makes no grid assumption, the surviving tokens can still be processed directly.

2. The prediction module is a lightweight MLP

import torch
import torch.nn as nn


class PredictorLG(nn.Module):
    """ Token score predictor (the lightweight prediction module)
    """
    def __init__(self, embed_dim=384):
        super().__init__()
        # Local modeling: a linear projection applied to each token
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )
        # The MLP in π = Softmax(MLP(z)): predicts a 2-dim vector per token,
        # which is fed to softmax / Gumbel-Softmax to decide whether the token
        # should be masked out. LogSoftmax is fine here because
        # F.gumbel_softmax expects (unnormalized) log-probabilities.
        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x, policy):
        # x: the current tokens, (B, N, C)
        # policy: the current 0/1 mask, (B, N, 1); 0 marks tokens that no
        # longer take part in subsequent computation
        x = self.in_conv(x)  # linear projection modeling local (per-token) information
        B, N, C = x.size()
        local_x = x[:, :, :C//2]
        # Global vector: pool only over the tokens that are still kept
        global_x = (x[:, :, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)
        # Concatenate the global vector with each token's local vector
        x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1)
        # A simple MLP outputs a keep/drop score per token
        return self.out_conv(x)
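A quick smoke test of the module (hypothetical shapes following DeiT-S: 196 spatial tokens of dimension 384; the all-ones policy means nothing has been pruned yet):

tokens = torch.randn(4, 196, 384)                 # 4 images, 196 spatial tokens
policy = torch.ones(4, 196, 1)                    # 1 = kept; nothing pruned yet
predictor = PredictorLG(embed_dim=384)
scores = predictor(tokens, policy)
print(scores.shape)                               # torch.Size([4, 196, 2]), log-probabilities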
Notation

  • $\hat{D} \in \{0, 1\}^N$: the binary decision mask; the `policy` in the code
  • $x \in \mathbb{R}^{N \times C}$: the tokens
  • $z^{local} = \mathrm{MLP}(x) \in \mathbb{R}^{N \times C'}$: the local (per-token) information
  • $z^{global} = \mathrm{Agg}(\mathrm{MLP}(x), \hat{D}) \in \mathbb{R}^{C'}$: the global information, aggregated only over the kept tokens
  • $z_i = [z_i^{local}, z_i^{global}],\ 1 \le i \le N$
  • $\pi = \mathrm{Softmax}(\mathrm{MLP}(z)) \in \mathbb{R}^{N \times 2}$, where $\pi_{i,0}$ denotes the probability of dropping the $i$-th token and $\pi_{i,1}$ the probability of keeping it
  • $D = \text{Gumbel-Softmax}(\pi)_{*,1} \in \{0, 1\}^N$: the binarization step; the Gumbel-Softmax trick keeps the whole binarization differentiable
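Mapping the symbols onto code (an illustrative snippet; `predictor`, `tokens` and `policy` are from the smoke test above). Note one cosmetic mismatch: the reference implementation below treats channel 0 of the predictor output as the keep probability, whereas the paper's notation uses column 1:

import torch.nn.functional as F

log_pi = predictor(tokens, policy)                 # log π, shape (B, N, 2)
D = F.gumbel_softmax(log_pi, hard=True)[..., 0:1]  # one 0/1 keep decision per token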
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from collections import OrderedDict

from timm.models.layers import trunc_normal_
# Block, PatchEmbed, HybridEmbed and batch_index_select are defined alongside
# this class in the DynamicViT repository.


class VisionTransformerDiffPruning(nn.Module):
    """ Vision Transformer with dynamic token pruning

    Based on the PyTorch impl of `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, 
                 pruning_loc=None, token_ratio=None, distill=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer (nn.Module): normalization layer
            pruning_loc (list[int]): indices of the blocks before which tokens are pruned
            token_ratio (list[float]): fraction of the initial tokens kept after each pruning stage
            distill (bool): if True, also return features and decisions for distillation
        """
        super().__init__()

        print('## diff vit pruning method')
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        predictor_list = [PredictorLG(embed_dim) for _ in range(len(pruning_loc))]

        self.score_predictor = nn.ModuleList(predictor_list)

        self.distill = distill

        self.pruning_loc = pruning_loc
        self.token_ratio = token_ratio

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        p_count = 0
        out_pred_prob = []
        init_n = 14 * 14  # number of spatial tokens for 224x224 input / patch 16
        # prev_decision: running keep-mask over the spatial tokens;
        # policy additionally covers the cls token (index 0), which is never pruned
        prev_decision = torch.ones(B, init_n, 1, dtype=x.dtype, device=x.device)
        policy = torch.ones(B, init_n + 1, 1, dtype=x.dtype, device=x.device)
        for i, blk in enumerate(self.blocks):
            if i in self.pruning_loc:
                # Score only the spatial tokens; the cls token is excluded
                spatial_x = x[:, 1:]
                pred_score = self.score_predictor[p_count](spatial_x, prev_decision).reshape(B, -1, 2)
                if self.training:
                    # Training: draw a differentiable 0/1 keep decision per token
                    # (channel 0 = keep); multiplying by prev_decision ensures a
                    # token, once dropped, stays dropped
                    hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision
                    out_pred_prob.append(hard_keep_decision.reshape(B, init_n))
                    cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
                    policy = torch.cat([cls_policy, hard_keep_decision], dim=1)
                    # Tokens are not physically removed during training; the block
                    # masks attention to the dropped tokens via the policy
                    x = blk(x, policy=policy)
                    prev_decision = hard_keep_decision
                else:
                    # Inference: actually discard tokens, keeping the top-k by score
                    score = pred_score[:, :, 0]
                    num_keep_node = int(init_n * self.token_ratio[p_count])
                    keep_policy = torch.argsort(score, dim=1, descending=True)[:, :num_keep_node]
                    # Gather indices: 0 is the cls token, spatial indices shift by 1
                    cls_policy = torch.zeros(B, 1, dtype=keep_policy.dtype, device=keep_policy.device)
                    now_policy = torch.cat([cls_policy, keep_policy + 1], dim=1)
                    x = batch_index_select(x, now_policy)
                    prev_decision = batch_index_select(prev_decision, keep_policy)
                    x = blk(x)
                p_count += 1
            else:
                if self.training:
                    x = blk(x, policy)
                else:
                    x = blk(x)

        x = self.norm(x)
        features = x[:, 1:]
        x = x[:, 0]
        x = self.pre_logits(x)
        x = self.head(x)
        if self.training:
            if self.distill:
                return x, features, prev_decision.detach(), out_pred_prob
            else:
                return x, out_pred_prob
        else:
            return x
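`batch_index_select` is a helper from the DynamicViT repository; a minimal sketch of what it has to do for 3-D inputs (the exact repo version also handles 2-D tensors; this is an assumed re-implementation, not the original):

def batch_index_select(x, idx):
    # Gather idx (B, K) token indices from x (B, N, C), per sample in the batch
    B, N, C = x.shape
    offset = torch.arange(B, device=x.device).view(B, 1) * N
    return x.reshape(B * N, C)[(idx + offset).reshape(-1)].reshape(B, -1, C)

And a hypothetical construction following the paper's defaults (DeiT-S sized, pruning before blocks 3/6/9 with base keep ratio 0.7; token_ratio is cumulative because each stage is measured against the initial 196 tokens):

base_rate = 0.7
model = VisionTransformerDiffPruning(
    embed_dim=384, depth=12, num_heads=6,
    pruning_loc=[3, 6, 9],
    token_ratio=[base_rate, base_rate**2, base_rate**3])
logits = model.eval()(torch.randn(2, 3, 224, 224))   # (2, 1000)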

Gumbel-Softmax

Gumbel-Softmax paper: https://arxiv.org/abs/1611.01144
PyTorch implementation: torch.nn.functional.gumbel_softmax
Gumbel-Softmax draws a re-parameterized, differentiable sample from a categorical distribution $\pi$. With i.i.d. Gumbel noise $g_i = -\log(-\log u_i)$, $u_i \sim \mathrm{Uniform}(0, 1)$:

  • hard mode
    The forward pass outputs the one-hot vector $\mathrm{onehot}(\arg\max_i (\log \pi_i + g_i))$, while the backward pass uses the gradient of the soft sample (the straight-through estimator), so the binarized decision stays differentiable.

  • soft mode
    $y_i = \dfrac{\exp((\log \pi_i + g_i)/\tau)}{\sum_j \exp((\log \pi_j + g_j)/\tau)}$
    $\tau$ is the temperature coefficient: as $\tau \to 0$, the soft samples approach one-hot draws, i.e. Gumbel-Softmax can approximate sampling from the distribution $\pi$.
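As a sanity check on the soft-mode formula, a soft sample can be written in a few lines (illustrative only; in practice use F.gumbel_softmax, demonstrated below):

import torch

def gumbel_softmax_sample(log_pi, tau=1.0):
    u = torch.rand_like(log_pi)
    g = -torch.log(-torch.log(u + 1e-20) + 1e-20)    # Gumbel(0, 1) noise
    return torch.softmax((log_pi + g) / tau, dim=-1)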

import torch
import torch.nn.functional as F

logits = torch.randn(20, 32)
# Sample soft categorical using reparametrization trick:
F.gumbel_softmax(logits, tau=1, hard=False)
# Sample hard categorical using "Straight-through" trick:
F.gumbel_softmax(logits, tau=1, hard=True)

hard_keep_decision = F.gumbel_softmax(torch.rand(10, 2), hard=True)
print(hard_keep_decision)
The random input torch.rand(10, 2):

tensor([[0.9683, 0.4070],
        [0.9941, 0.7483],
        [0.1180, 0.9593],
        [0.0570, 0.7179],
        [0.3562, 0.1069],
        [0.9350, 0.7042],
        [0.4586, 0.9105],
        [0.4754, 0.8280],
        [0.7535, 0.4763],
        [0.6013, 0.9938]])
The Gumbel-Softmax output (hard mode gives 0/1 one-hot rows; non-hard mode gives fractional values):
tensor([[1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.]])
With the input shape changed to (10, 6), i.e. torch.rand(10, 6), the result:
tensor([[0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 0.]])
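One more point worth checking: even with hard=True, gradients flow back to the logits through the soft sample (the straight-through trick), which is exactly what lets DynamicViT train its predictor end to end. A minimal check (reusing the imports above):

logits = torch.randn(4, 2, requires_grad=True)
y = F.gumbel_softmax(logits, hard=True)       # one-hot in the forward pass
loss = y[:, 0].sum()                          # any loss on the hard sample
loss.backward()                               # backward uses the soft sample's gradient
print(logits.grad)                            # non-zero: straight-through works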

References and further reading

[NeurIPS 2021] DynamicViT: Dynamic Token Sparsification ViT, paper walkthrough
softmin
https://pytorch.org/docs/stable/generated/torch.nn.Softmax2d.html?highlight=softmax2d
Gumbel-Softmax explained (in Chinese)
Gumbel distribution
http://amid.fish/humble-gumbel
