1. Background
This is a CVPR 2023 paper whose main contribution is a heavy rework of the self-attention mechanism in Vision Transformers. I suspect it may be useful for my own work, so I'm giving it a close read today. (I'll take this post down if it infringes on anything.) (Arguments and corrections are very welcome.)
Paper: [2211.11167] Vision Transformer with Super Token Sampling (arxiv.org)
Code: hhb072/STViT (github.com)
In the paper, the original, unmodified attention mechanism is referred to as the "vanilla" one.
2. The paper's Introduction

3. The paper's Related Works
Mainly the classic ViT line of work and superpixel algorithms.
4. Model Details
4.1. Overall Architecture
The overall pipeline of the proposed model, built from super token transformer (STT) blocks, is as follows:
Feature extraction starts with a convolution stem (the "stem" is the first few layers of a model, which perform the initial feature extraction); several earlier papers have used one and found it works better than the vanilla ViT tokenization. The resulting tokens then pass through four stages of STT blocks, which build a hierarchical representation while progressively reducing the number of tokens, and the final features feed the classification head. A rough sketch of this layout is given below.
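Here is a hypothetical end-to-end sketch of that layout. The structure (stem, four stages, head) follows the description above, but the class name, channel widths, depths, and choice of downsampling layer are my own illustrative assumptions, not the official configuration.

import torch
import torch.nn as nn

class STViTSketch(nn.Module):
    """Illustrative skeleton only: conv stem -> 4 stages of STT blocks -> classifier."""
    def __init__(self, block_fn, dims=(64, 128, 256, 512), depths=(2, 2, 6, 2), num_classes=1000):
        super().__init__()
        # convolution stem: two stride-2 convs, so tokens start at 1/4 resolution
        self.stem = nn.Sequential(
            nn.Conv2d(3, dims[0] // 2, 3, 2, 1), nn.GELU(),
            nn.Conv2d(dims[0] // 2, dims[0], 3, 2, 1),
        )
        # four stages of STT blocks; block_fn(dim) is assumed to build one STT block
        self.stages = nn.ModuleList(
            [nn.Sequential(*[block_fn(d) for _ in range(n)]) for d, n in zip(dims, depths)]
        )
        # stride-2 convs between stages halve the token map (hierarchical representation)
        self.downs = nn.ModuleList(
            [nn.Conv2d(dims[i], dims[i + 1], 3, 2, 1) for i in range(len(dims) - 1)]
        )
        self.head = nn.Linear(dims[-1], num_classes)

    def forward(self, x):                       # x: (B, 3, H, W)
        x = self.stem(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i < len(self.downs):
                x = self.downs[i](x)
        return self.head(x.mean(dim=(2, 3)))    # global average pool -> class logits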
Each super token transformer (STT) block consists of three key modules: Convolutional Position Embedding (CPE), Super Token Attention (STA), and a Convolutional Feed-Forward Network (ConvFFN), connected one after another in a simple sequential structure:
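As a reading aid, here is a minimal sketch of one STT block in that sequential form. The module order and the residual connections follow the paper's description; the exact normalization layers and the ConvFFN internals are my assumptions and should be checked against the official repo.

import torch.nn as nn

class STTBlockSketch(nn.Module):
    """Illustrative only: CPE -> STA -> ConvFFN, each wrapped in a residual connection."""
    def __init__(self, dim, stoken_size, mlp_ratio=4):
        super().__init__()
        # CPE: a 3x3 depthwise conv that injects positional information
        self.cpe = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = nn.BatchNorm2d(dim)   # assumed; the repo may use a different norm
        # STA: the StokenAttention module shown in Section 4.2 below
        self.sta = StokenAttention(dim, stoken_size)
        self.norm2 = nn.BatchNorm2d(dim)   # assumed
        # ConvFFN: pointwise conv -> depthwise conv -> GELU -> pointwise conv (assumed layout)
        hidden = dim * mlp_ratio
        self.ffn = nn.Sequential(
            nn.Conv2d(dim, hidden, 1),
            nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden),
            nn.GELU(),
            nn.Conv2d(hidden, dim, 1),
        )

    def forward(self, x):                        # x: (B, C, H, W)
        x = x + self.cpe(x)                      # convolutional position embedding
        x = x + self.sta(self.norm1(x))          # super token attention
        x = x + self.ffn(self.norm2(x))          # convolutional feed-forward network
        return x

With the StokenAttention class from Section 4.2, the block_fn in the earlier sketch could then be something like lambda d: STTBlockSketch(d, stoken_size=(8, 8)).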
4.2. Super Token Attention
In the official repo, STA is implemented as the StokenAttention module below. Unfold, Fold and Attention are helper classes defined alongside it: Unfold(3) and Fold(3) gather and scatter the 3x3 grid neighbourhood of each super token, and Attention is a multi-head self-attention applied to a 2D feature map.

import torch
import torch.nn as nn
import torch.nn.functional as F

class StokenAttention(nn.Module):
    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False,
                 qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.n_iter = n_iter              # number of super token sampling iterations
        self.stoken_size = stoken_size    # (h, w): size of one super token grid cell
        self.scale = dim ** -0.5

        self.unfold = Unfold(3)
        self.fold = Fold(3)
        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                       qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)

    def stoken_forward(self, x):
        '''
        x: (B, C, H, W)
        '''
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        # pad so that H and W are divisible by the super token cell size
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))

        _, _, H, W = x.shape
        hh, ww = H // h, W // w   # super token grid; m = hh * ww super tokens

        # initialise each super token as the average of its h x w cell
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)

        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh*ww, h*w, C)

        # Super Token Sampling: iterative sparse association, computed without gradients
        with torch.no_grad():
            for idx in range(self.n_iter):
                stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh*ww, C, 9)
                # affinity of each pixel token to its 9 neighbouring super tokens
                affinity_matrix = pixel_features @ stoken_features * self.scale  # (B, hh*ww, h*w, 9)
                affinity_matrix = affinity_matrix.softmax(-1)  # (B, hh*ww, h*w, 9)
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                if idx < self.n_iter - 1:
                    # re-estimate super tokens as affinity-weighted averages of pixel features
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B*C, 9, hh, ww)).reshape(B, C, hh, ww)
                    stoken_features = stoken_features/(affinity_matrix_sum + 1e-12)  # (B, C, hh, ww)

        # final super token update, outside no_grad so gradients flow via pixel_features
        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B*C, 9, hh, ww)).reshape(B, C, hh, ww)
        stoken_features = stoken_features/(affinity_matrix_sum.detach() + 1e-12)  # (B, C, hh, ww)

        # multi-head self-attention among the (few) super tokens
        stoken_features = self.stoken_refine(stoken_features)

        # Token Upsampling: map the refined super tokens back to pixels via the affinities
        stoken_features = self.unfold(stoken_features)  # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh*ww, C, 9)  # (B, hh*ww, C, 9)
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        # crop off the padding added at the start
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]

        return pixel_features

    def direct_forward(self, x):
        # degenerate case: 1x1 super tokens, so just run attention on the full map
        B, C, H, W = x.shape
        stoken_features = x
        stoken_features = self.stoken_refine(stoken_features)
        return stoken_features

    def forward(self, x):
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        else:
            return self.direct_forward(x)
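A quick smoke test of the module, assuming Unfold, Fold and Attention have been imported from the official repo (the sizes here are just an example of mine):

x = torch.randn(2, 64, 56, 56)                      # (B, C, H, W)
sta = StokenAttention(dim=64, stoken_size=(8, 8))   # 8x8 cells on a 56x56 map -> 7x7 = 49 super tokens
y = sta(x)
print(y.shape)                                      # torch.Size([2, 64, 56, 56]), same shape as the input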
4.2.1. Super Token Sampling (STS)
This step borrows from a superpixel paper: super tokens are sampled from the visual tokens by learning sparse associations, essentially a soft clustering of tokens. In the code it lives in the first half of stoken_forward. The input is first padded with F.pad so that H and W are divisible by the cell size, and F.adaptive_avg_pool2d initialises each super token as the average of its h x w grid cell. Unfold(3) then gathers, for every grid position, the 3x3 neighbourhood of super tokens, so each pixel token only computes softmax-normalised affinities with the 9 surrounding super tokens rather than with all of them; this is the sparse association. The super tokens are then re-estimated as affinity-weighted averages of the pixel features. This is also where the name "super" comes from: unlike ViT or Swin, where a token is one fixed patch of features, a super token summarises a whole group of pixel tokens, in analogy to superpixels.
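Written out in my own notation (derived from the code above, not quoted verbatim from the paper), with X_i a pixel token, S_j a super token, and C the channel dimension, one sampling iteration is:

$$
Q_{ij} = \mathrm{softmax}_{j \in \mathcal{N}(i)}\!\left(\frac{X_i S_j^{\top}}{\sqrt{C}}\right),
\qquad
S_j = \frac{\sum_i Q_{ij}\, X_i}{\sum_i Q_{ij}},
$$

where N(i) is the 3x3 neighbourhood of super tokens around the cell containing token i. Q corresponds to affinity_matrix in the code, and the division by affinity_matrix_sum is exactly the normalisation in the second equation.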
4.2.2. Token Upsampling (TU)
This step maps the super tokens back to the original token space, and is implemented in the last part of stoken_forward. After stoken_refine has applied multi-head self-attention to the super tokens, the refined super token features are unfolded to each 3x3 neighbourhood again and multiplied by the transposed affinity matrix (stoken_features @ affinity_matrix.transpose(-1, -2)), which distributes every super token back onto its pixels according to the same sparse associations; a final reshape restores the (B, C, H, W) layout and the padding is cropped off. (Note that direct_forward is not part of TU: it is the fallback branch used when stoken_size is 1x1, i.e. when there is effectively no super token sampling at all.)
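In the same notation as before (again mine, not the paper's exact equations), the refine-and-upsample step is simply:

$$
S' = \mathrm{MHSA}(S),
\qquad
\hat{X} = Q\,S',
$$

i.e. the refined super tokens S' are redistributed to the pixels through the sparse association Q (the last matmul in the code), so every output token is a weighted mixture of its 9 surrounding, attention-refined super tokens.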
4.2.3. Complexity Analysis
In effect, this work factorises the expensive N x N attention map A(X) into the product of a sparse matrix and a small matrix: the sparse association Q (each token only scores its 9 neighbouring super tokens) and the m x m attention map A(S) over the super tokens. The paper first analyses the complexity of global self-attention and then that of STA; since the number of super tokens m is much smaller than the number of tokens N, the complexity clearly goes down.
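As a rough back-of-the-envelope check, derived from the code above rather than copied from the paper (the constants are my own, and the linear Q/K/V/output projections are ignored on both sides):

def gsa_cost(N, C):
    # global self-attention: Q @ K^T and A @ V each cost about N*N*C multiply-adds
    return 2 * N * N * C

def sta_cost(N, C, m, n_iter=1):
    # sparse association: each of the N tokens scores only its 3x3 = 9 neighbouring
    # super tokens, and the super tokens are re-estimated from those weights
    sampling = n_iter * 2 * 9 * N * C
    # multi-head self-attention over the m super tokens only
    refine = 2 * m * m * C
    # token upsampling: redistributing the super tokens is again a 9-sparse product
    upsample = 9 * N * C
    return sampling + refine + upsample

# e.g. a 56x56 stage-1 token map with C = 64 and 8x8 super token cells (m = 49):
N, C, m = 56 * 56, 64, 7 * 7
print(gsa_cost(N, C) / sta_cost(N, C, m))   # roughly two orders of magnitude cheaper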
5. Experiments
The paper plugs this backbone into several tasks (image classification, detection, segmentation) and reports decent gains across the board. Conclusion: STA yes!