快读
-
简述:配备动态令牌稀疏化框架,与ImageNet上最先进的CNN和视觉转换器相比,DynamicViT模型可以实现非常有竞争力的复杂性/准确性权衡。
-
出发点:我们观察到,视觉转换器(Vision Transformer)的最终预测仅依赖于信息量最大的一部分令牌(token),仅凭这一子集就足以完成准确的图像识别。
-
方法:基于这一观察结果,我们提出了一个动态令牌稀疏化框架,以根据输入逐步和动态地修剪冗余令牌。
-
方法:具体来说,我们设计了一个轻量级的预测模块,以估计给定当前特征的每个令牌的重要性分数。该模块被添加到不同的层,以分层修剪冗余令牌。
-
方法:为了以端到端的方式优化预测模块,我们提出了一种注意力掩蔽策略,通过阻止某个令牌与其他令牌的交互,以可微分的方式修剪该令牌。
-
效果:得益于自注意力(self-attention)机制的特性,非结构化的稀疏令牌仍然是硬件友好的,这使得我们的框架很容易实现实际的加速。通过分层修剪66%的输入令牌,我们的方法将FLOPs大幅减少了31%~37%,并将吞吐量提高了40%以上,而各种视觉转换器的精度下降幅度在0.5%以内。
1. 卷积的下采样与DynamicViT
2.prediction模块是比较轻量化的MLP网络
class PredictorLG(nn.Module):
    """Lightweight MLP that scores every token for keep/drop.

    Each token is first projected, then half of its channels are used as a
    local feature and the other half are pooled (over currently active
    tokens only) into a global feature.  The concatenation of both is fed to
    a small MLP that outputs two log-probabilities per token, i.e. the
    ``pi = Softmax(MLP(z))`` term (returned in log space), which callers feed
    to Gumbel-Softmax or rank directly.

    Args:
        embed_dim (int): channel dimension of the incoming tokens.  It must
            be even, because the projected feature is split into two equal
            halves.
    """

    def __init__(self, embed_dim=384):
        super().__init__()
        # Per-token projection (LayerNorm -> Linear -> GELU) modelling local
        # information.
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
        )
        # The MLP in pi = Softmax(MLP(z)): predicts a 2-dim vector per token
        # that decides whether the token should be masked out.
        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x, policy):
        """Return per-token log-probabilities of shape ``(B, N, 2)``.

        Args:
            x: current tokens, unpacked below as ``(B, N, C)``.
            policy: current mask multiplied onto the tokens and summed over
                dim 1, so it must broadcast against ``(B, N, C // 2)``
                (the caller in this file passes ``(B, N, 1)``).  A value of 0
                marks a token that no longer takes part in later computation.

        Returns:
            Log-softmax scores over the last dimension (size 2) per token.
        """
        x = self.in_conv(x)
        B, N, C = x.size()
        half = C // 2
        local_x = x[:, :, :half]

        # Global feature: average of the second half of the channels, taken
        # only over tokens that are still active according to `policy`.
        active_count = torch.sum(policy, dim=1, keepdim=True)
        # Guard against 0/0 -> NaN when a sample has no active token left;
        # the numerator is 0 in that case, so the pooled feature becomes 0.
        active_count = active_count.clamp(min=1e-6)
        global_x = (x[:, :, half:] * policy).sum(dim=1, keepdim=True) / active_count

        # Every token sees its own local feature plus the shared global one.
        x = torch.cat([local_x, global_x.expand(B, N, half)], dim=-1)
        return self.out_conv(x)
符号 | 表示 |
---|---|
$\hat{D} \in \{0, 1\}^N$ | 代码中的policy |
$x \in \mathbb{R}^{N \times C}$ | Token |
$z^{local} = \mathrm{MLP}(x) \in \mathbb{R}^{N \times C'}$ | local的信息 |
$z^{global} = \mathrm{AGG}\{\mathrm{MLP}(x), \hat{D}\} \in \mathbb{R}^{C'}$ | global的信息 |
$z_i = [z^{local}_i, z^{global}_i],\ 1 \le i \le N$ | 每个token的局部特征与全局特征拼接 |
$\pi = \mathrm{Softmax}(\mathrm{MLP}(z)) \in \mathbb{R}^{N \times 2}$ | where $\pi_{i,0}$ denotes the probability of dropping the i-th token and $\pi_{i,1}$ is the probability of keeping it.(注意:本文 VisionTransformerDiffPruning 的代码取的是第0维,即 `pred_score[:, :, 0]`,作为保留分数) |
$D = \text{Gumbel-Softmax}(\pi)_{*,1} \in \{0, 1\}^N$ | Gumbel-Softmax 二值化操作,Gumbel-Softmax操作能让整个二值化的过程是可导的 |
class VisionTransformerDiffPruning(nn.Module):
    """Vision Transformer with learned, progressive token pruning.

    The backbone is a PyTorch impl of `An Image is Worth 16x16 Words:
    Transformers for Image Recognition at Scale` -
    https://arxiv.org/abs/2010.11929

    In front of every block listed in ``pruning_loc`` a ``PredictorLG``
    scores the patch tokens.  In training mode a hard Gumbel-Softmax sample
    turns the scores into a 0/1 mask that is handed to the blocks as
    ``policy``; in eval mode the top-scoring tokens are physically selected
    and the rest are dropped from the sequence.  The class token is always
    kept.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
                 pruning_loc=None, token_ratio=None, distill=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
            norm_layer: (nn.Module): normalization layer
            pruning_loc (Optional[Sequence[int]]): indices of the blocks in front of
                which tokens are pruned; one predictor is created per entry.
                ``None`` means no pruning stage.
            token_ratio (Optional[Sequence[float]]): per pruning stage, the fraction
                of the initial patch tokens kept at inference time (only read
                in eval mode by ``forward``).
            distill (bool): if True, the training-mode forward additionally
                returns the patch features and the detached final decision mask.
        """
        super().__init__()
        print('## diff vit pruning method')
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        # The original crashed on len(None) with the default argument; treat
        # a missing pruning_loc as "no pruning stages".
        if pruning_loc is None:
            pruning_loc = ()

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # One score predictor per pruning stage.
        self.score_predictor = nn.ModuleList([PredictorLG(embed_dim) for _ in pruning_loc])

        self.distill = distill
        self.pruning_loc = pruning_loc
        self.token_ratio = token_ratio

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, constants for biases/LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classification head for ``num_classes`` outputs.

        ``global_pool`` is accepted for signature compatibility and unused.
        """
        self.num_classes = num_classes
        # The head consumes the output of pre_logits, whose width is
        # num_features (== representation_size when that layer is enabled),
        # not necessarily embed_dim.
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        """Classify a batch of images, pruning tokens at the configured blocks.

        Returns:
            eval mode: the head output ``x``.
            training mode: ``(x, out_pred_prob)`` where ``out_pred_prob`` holds
            one ``(B, init_n)`` hard keep mask per pruning stage; with
            ``distill=True`` it is
            ``(x, features, prev_decision.detach(), out_pred_prob)``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        p_count = 0  # index of the next pruning stage / predictor
        out_pred_prob = []
        # Number of patch tokens (sequence length minus the class token).
        # Derived from the data instead of the former hard-coded 14 * 14 so
        # that other image / patch sizes work.
        init_n = x.shape[1] - 1
        prev_decision = torch.ones(B, init_n, 1, dtype=x.dtype, device=x.device)
        policy = torch.ones(B, init_n + 1, 1, dtype=x.dtype, device=x.device)
        for i, blk in enumerate(self.blocks):
            if i in self.pruning_loc:
                spatial_x = x[:, 1:]  # patch tokens only; the class token is never scored
                pred_score = self.score_predictor[p_count](spatial_x, prev_decision).reshape(B, -1, 2)
                if self.training:
                    # Hard 0/1 sample; channel 0 is used as the keep decision.
                    # Multiplying by prev_decision keeps already-dropped
                    # tokens dropped.
                    hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision
                    out_pred_prob.append(hard_keep_decision.reshape(B, init_n))
                    cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
                    policy = torch.cat([cls_policy, hard_keep_decision], dim=1)
                    x = blk(x, policy=policy)
                    prev_decision = hard_keep_decision
                else:
                    # Inference: keep a fixed number of top-scoring tokens and
                    # physically shrink the sequence.
                    score = pred_score[:, :, 0]
                    num_keep_node = int(init_n * self.token_ratio[p_count])
                    keep_policy = torch.argsort(score, dim=1, descending=True)[:, :num_keep_node]
                    # Index 0 is the class token; patch indices are shifted by one.
                    cls_policy = torch.zeros(B, 1, dtype=keep_policy.dtype, device=keep_policy.device)
                    now_policy = torch.cat([cls_policy, keep_policy + 1], dim=1)
                    # batch_index_select is defined elsewhere; presumably a
                    # per-sample gather along dim 1 - TODO confirm.
                    x = batch_index_select(x, now_policy)
                    prev_decision = batch_index_select(prev_decision, keep_policy)
                    x = blk(x)
                p_count += 1
            else:
                if self.training:
                    # Keep applying the most recent mask between pruning stages.
                    x = blk(x, policy)
                else:
                    x = blk(x)

        x = self.norm(x)
        features = x[:, 1:]
        x = x[:, 0]
        x = self.pre_logits(x)
        x = self.head(x)
        if self.training:
            if self.distill:
                return x, features, prev_decision.detach(), out_pred_prob
            else:
                return x, out_pred_prob
        else:
            return x
Gumbel-Softmax
Gumbel-Softmax paper
pytorch realization
- hard mode
- soft mode
$\tau$ 为温度系数;Gumbel-Softmax 可以逼近 $\pi$ 的分布
# Demo input: random unnormalized scores of shape (20, 32).
logits = torch.randn(20, 32)
# Sample soft categorical using reparametrization trick:
# (hard=False returns fractional values)
F.gumbel_softmax(logits, tau=1, hard=False)
# Sample hard categorical using "Straight-through" trick:
# (hard=True returns values that are exactly 0 or 1)
F.gumbel_softmax(logits, tau=1, hard=True)
import torch
import torch.nn.functional as F
# Draw a hard Gumbel-Softmax sample from random scores of shape (10, 2);
# with hard=True every printed row is one-hot.
hard_keep_decision = F.gumbel_softmax(torch.rand(10,2), hard=True)
print(hard_keep_decision)
x tensor([[0.9683, 0.4070],
[0.9941, 0.7483],
[0.1180, 0.9593],
[0.0570, 0.7179],
[0.3562, 0.1069],
[0.9350, 0.7042],
[0.4586, 0.9105],
[0.4754, 0.8280],
[0.7535, 0.4763],
[0.6013, 0.9938]])
Gumbel-Softmax
hard模式为0或1,非hard模式为小数
tensor([[1., 0.],
[0., 1.],
[0., 1.],
[0., 1.],
[0., 1.],
[1., 0.],
[1., 0.],
[0., 1.],
[0., 1.],
[0., 1.]])
- 换为10,6的结果
tensor([[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 1., 0., 0., 0.]])
参考与更多
添加链接描述
添加链接描述
【NIPS 2021】DynamicViT:动态Token稀疏化ViT 论文解读
softmin
https://pytorch.org/docs/stable/generated/torch.nn.Softmax2d.html?highlight=softmax2d
Gumbel-softmax 中文解读
耿贝尔分布
http://amid.fish/humble-gumbel