论文阅读记录(一)——Transformer in Transformer

论文阅读记录(一)——Transformer in Transformer

    这个项目是用来对大四的论文阅读做一下记录,希望可以在大四的这段时间尽快的融入未来的科研工作,希望可以和大家一起进步。

相对应的gihub项目为
争取一周3-4次更行

这里先介绍几个很好用的学习项目

PyTorch example

PyTorch的官方模板平台,学习PyTorch进阶操作最好的样例

如何读论文——论文精读

李沐大神在B站分享的项目,持续跟进可以很好规范自己的科研习惯~

TensorFlow2教程

B站上学习TensorFlow2的一个很好的教程,毕竟我们学习深度学习框架不能一直停留在使用python train.py的程度,分布式、静态图加速训练等都很有助于我们平时提升训练效率

timm

非常非常好用的一个分类模型库,作者对目前大部分SOTA的ImageNet分类模型进行了整合,并具有复现ImageNet训练结果的能力。目前比如非常优秀的模型SwinTransformer等都是基于timm库实现的。

Transformer in Transformer

    文章引言就非常有意思,不知道是不是上传的时候就是已经注明有MindSpore的框架的(如果是那么一看就知道是诺亚方舟实验室的作品了),具体的代码可以参考码云Github

知乎大佬的论文解读

    因为已经有大佬在知乎中对论文进行了详细的解读,因此就简单的结合代码对论文进行一个对应


# 首先不得不佩服Ross Wightman的timm库,助力极简的算法开发.这里我们先看一下部分的源码实现,以便在后面可以更好的阅读论文的相关内容

    ...
    self.outer_tokens = nn.Parameter(torch.zeros(1, num_patches, outer_dim), requires_grad=False)
    self.outer_pos = nn.Parameter(torch.zeros(1, num_patches + 1, outer_dim))
    self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dim))    
    ...
# 第一个重点,作者在这里使用了内部和外部的位置编码
class Block(nn.Module):
    """ TNT Block
    """
    ...
        self.has_inner = inner_dim > 0
        if self.has_inner:
            # Inner
# 第二个重点,作者用inner_dim参数控制了Block模块,这里也就是TNT的实现过程,主要看两个地方就可以,其他都是基本的imagenet train的pipline

Transformer in Transformer的结构图,主要可以关注加法的地方

    其实所谓的inner_block就是对图片以patch的形式进行同样的vit的流程

    比如输入是
[ B a t c h S i z e , 3 , 224 , 224 ] [BatchSize, 3, 224, 224] [BatchSize,3,224,224]
可以得到其 p a t c h _ s i z e = 16 patch\_size=16 patch_size=16的张量为
[ B a t c h S i z e , 224 16 × 224 16 , 3 × 16 × 16 ] [BatchSize, \frac{224}{16}\times\frac{224}{16},3\times16\times16] [BatchSize,16224×16224,3×16×16]

    前者为outer_token,后者为inner_token,对两者进行分别的vit运算就是所谓的Transformer in Transformer

        inner_tokens = self.patch_embed(x) + self.inner_pos  # B*N, 8*8, C
        # print("self.inner_pos", self.inner_pos.shape)  # ([1, 16, 40])
        # print("inner_tokens", inner_tokens.shape)

        outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, self.num_patches, -1))))
        outer_tokens = torch.cat((self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)

        outer_tokens = outer_tokens + self.outer_pos
        outer_tokens = self.pos_drop(outer_tokens)

        for blk in self.blocks:
            inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens)
'''
    在Block函数中,作者就对两个token进行了分别的定义运算
    具体可以看到inner_token和outer_token之间存在一个加法的交互
    Block通过has_inner操作(其实全部的block都是has_inner=True的)
'''

    核心的代码基本就是这两个地方,其他的东西可以结合代码具体debug一下就可以看懂运行的流程了。

参考资料

  1. Transformer in Transformer论文解读
  2. timm
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值