Arxiv 2106 - CAT: Cross Attention in Vision Transformer
- 论文:https://arxiv.org/abs/2106.05786
- 代码:https://github.com/linhezheng19/CAT
- 详细解读:https://mp.weixin.qq.com/s/VJCDAo94Uo_OtflSHRc1AQ
- 核心动机:使用patch内部和patch之间attention简化了全局attention计算。
本文仅做核心模块的粗略说明,力求对本文工作核心差异的完整展示,具体细节可见参考上面的解读文章。
主要内容
- Cross Attention Block (CAB) = Inner-Patch Self-Attention Block (IPSA) + Cross-Patch Self-Attention Block (CPSA):
- IPSA:就是标准的基于patch的attention,即attention的输入为
B*nph*npw,ph*pw,C
大小的tensor,得到的是空间大小为ph*pw,ph*pw
的attention矩阵。该模块建模了patch内部的全局关系。 - CPSA:这里处理的方式和以往的改进不太一样。这里attention计算的输入为
B*C,nph*npw,ph*pw
。对应的attention矩阵大小为nph*npw,nph*npw
,这里计算过程中是吧每个patch内部单一通道上的空间维度作为了每个patch信息的表示,从而通过相似性计算将这一维度给吸收了。这一模块基于通道独立的操作设计,构建了全局patch之间轻量的信息交互形式。
- IPSA:就是标准的基于patch的attention,即attention的输入为
核心代码
x = x.view(B, H, W, C)
# partition
patches = partition(x, self.patch_size) # nP*B, patch_size, patch_size, C
patches = patches.view(-1, self.patch_size * self.patch_size, C) # nP*B, patch_size*patch_size, C
# IPSA or CPSA
if self.attn_type == "ipsa":
attn = self.attn(patches) # nP*B, patch_size*patch_size, C
elif self.attn_type == "cpsa":
patches = patches.view(B, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2, C).permute(0, 3, 1, 2).contiguous()
patches = patches.view(-1, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2) # nP*B*C, nP*nP, patch_size*patch_size
attn = self.attn(patches).view(B, C, (H // self.patch_size) * (W // self.patch_size), self.patch_size ** 2)
attn = attn.permute(0, 2, 3, 1).contiguous().view(-1, self.patch_size ** 2, C) # nP*B, patch_size*patch_size, C
else :
raise NotImplementedError(f"Unkown Attention type: {self.attn_type}")
# reverse opration of partition
attn = attn.view(-1, self.patch_size, self.patch_size, C)
x = reverse(attn, self.patch_size, H, W) # B H' W' C
x = x.view(B, H * W, C)