论文阅读记录(一)——Transformer in Transformer
这个项目是用来对大四的论文阅读做一下记录,希望可以在大四的这段时间尽快的融入未来的科研工作,希望可以和大家一起进步。
相对应的gihub项目为
争取一周3-4次更行
这里先介绍几个很好用的学习项目
PyTorch的官方模板平台,学习PyTorch进阶操作最好的样例
李沐大神在B站分享的项目,持续跟进可以很好规范自己的科研习惯~
B站上学习TensorFlow2的一个很好的教程,毕竟我们学习深度学习框架不能一直停留在使用python train.py的程度,分布式、静态图加速训练等都很有助于我们平时提升训练效率
非常非常好用的一个分类模型库,作者对目前大部分SOTA的ImageNet分类模型进行了整合,并具有复现ImageNet训练结果的能力。目前比如非常优秀的模型SwinTransformer等都是基于timm库实现的。
Transformer in Transformer
文章引言就非常有意思,不知道是不是上传的时候就是已经注明有MindSpore的框架的(如果是那么一看就知道是诺亚方舟实验室的作品了),具体的代码可以参考码云和Github
知乎大佬的论文解读
因为已经有大佬在知乎中对论文进行了详细的解读,因此就简单的结合代码对论文进行一个对应
# 首先不得不佩服Ross Wightman的timm库,助力极简的算法开发.这里我们先看一下部分的源码实现,以便在后面可以更好的阅读论文的相关内容
...
self.outer_tokens = nn.Parameter(torch.zeros(1, num_patches, outer_dim), requires_grad=False)
self.outer_pos = nn.Parameter(torch.zeros(1, num_patches + 1, outer_dim))
self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dim))
...
# 第一个重点,作者在这里使用了内部和外部的位置编码
class Block(nn.Module):
""" TNT Block
"""
...
self.has_inner = inner_dim > 0
if self.has_inner:
# Inner
# 第二个重点,作者用inner_dim参数控制了Block模块,这里也就是TNT的实现过程,主要看两个地方就可以,其他都是基本的imagenet train的pipline
其实所谓的inner_block就是对图片以patch的形式进行同样的vit的流程
比如输入是
[
B
a
t
c
h
S
i
z
e
,
3
,
224
,
224
]
[BatchSize, 3, 224, 224]
[BatchSize,3,224,224]
可以得到其
p
a
t
c
h
_
s
i
z
e
=
16
patch\_size=16
patch_size=16的张量为
[
B
a
t
c
h
S
i
z
e
,
224
16
×
224
16
,
3
×
16
×
16
]
[BatchSize, \frac{224}{16}\times\frac{224}{16},3\times16\times16]
[BatchSize,16224×16224,3×16×16]
前者为outer_token,后者为inner_token,对两者进行分别的vit运算就是所谓的Transformer in Transformer
inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
# print("self.inner_pos", self.inner_pos.shape) # ([1, 16, 40])
# print("inner_tokens", inner_tokens.shape)
outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, self.num_patches, -1))))
outer_tokens = torch.cat((self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)
outer_tokens = outer_tokens + self.outer_pos
outer_tokens = self.pos_drop(outer_tokens)
for blk in self.blocks:
inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens)
'''
在Block函数中,作者就对两个token进行了分别的定义运算
具体可以看到inner_token和outer_token之间存在一个加法的交互
Block通过has_inner操作(其实全部的block都是has_inner=True的)
'''
核心的代码基本就是这两个地方,其他的东西可以结合代码具体debug一下就可以看懂运行的流程了。