DeepViT, DeiT, CaiT, T2T-ViT, Cross-ViT, PiT, LeViT, and CvT: Algorithm Notes

Changing Functions Without Changing the Network Structure


DeepViT (Modifying the Attention-Value Computation)

Findings

  1. In the original ViT, adding more Transformer layers does not improve the results and can even degrade them (a 32-layer model performs worse than a 24-layer one).
  2. As ViT gets deeper, the attention maps become increasingly similar after a certain number of layers.
  3. This phenomenon is called attention collapse.

Solution

When computing the attention output, the attention weights of the different heads are mixed with one another (implemented with a learnable matrix).
The original attention formula:
$\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}(\frac{QK^T}{\sqrt d})V$
The new Re-Attention formula:
$\mathrm{Re\text{-}Attention}(Q,K,V)=\mathrm{Norm}(\theta^T(\mathrm{Softmax}(\frac{QK^T}{\sqrt d})))V$

Code

  • The ReAttention class
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class ReAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
        # the learnable head-mixing matrix θ

        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )
        # the newly added Norm layer

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # Calculate attention weight

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        # re-attention

        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)
        # compute the new attention weights

        # aggregate and out

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out
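
A minimal usage sketch of the module above (the dimensions are illustrative assumptions, not values from the paper):

import torch

attn = ReAttention(dim = 256, heads = 8, dim_head = 64)
x = torch.randn(2, 65, 256)     # (batch, number of tokens, embedding dim)
out = attn(x)                   # head-mixed attention output, shape (2, 65, 256)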

Purpose and Advantages

  • The goal is to replace overly crude fixes such as dropout.
  • Its main advantage is that it is very simple to implement.

DeiT (Modifying the Loss)

How It Is Applied

  1. Take a well-performing model as the teacher; experiments show that using a CNN as the teacher works better.
  2. Use a ViT model as the student.
  3. Append a distillation token to the end of the input sequence as the basis for computing the distillation output.
  4. Use the computed distillation result as part of the loss.

Contributions

Two distillation schemes (two formulas for the distillation loss) are proposed:

  • Soft distillation
    $L_{\mathrm{global}} = (1-\lambda)L_{CE}(\mathrm{Softmax}(Z_s),y)+\lambda\tau^2\,\mathrm{KL}(\mathrm{Softmax}(Z_s/\tau),\mathrm{Softmax}(Z_t/\tau))$

$Z_t$ is the output of the teacher model
$Z_s$ is the output of the student model
$y$ is the ground-truth label
$L_{CE}$ is the cross entropy and $\mathrm{KL}$ is the Kullback-Leibler divergence
$\tau$ is the distillation temperature

  • Hard-label distillation
    $L_{\mathrm{global}}^{\mathrm{hardDistill}} = \frac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),y)+\frac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),\mathrm{argmax}(Z_t))$

In the Code

  • Appending the distillation token to the input
import torch
import torch.nn.functional as F
from torch import nn
from einops import repeat

def exists(val):
    return val is not None

# classes
class DistillMixin: 
    def forward(self, img, distill_token = None):
        distilling = exists(distill_token)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        if distilling:
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
            x = torch.cat((x, distill_tokens), dim = 1)
            # append the distillation token at the end of the sequence

        x = self._attend(x)

        if distilling:
            x, distill_tokens = x[:, :-1], x[:, -1]

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        out = self.mlp_head(x)

        if distilling:
            return out, distill_tokens

        return out
  • Computing the loss
class DistillWrapper(nn.Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5,
        hard = False
    ):
        super().__init__()
        assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha
        self.hard = hard

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        b, *_ = img.shape
        alpha = alpha if exists(alpha) else self.alpha
        T = temperature if exists(temperature) else self.temperature

        with torch.no_grad():
            teacher_logits = self.teacher(img)

        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        loss = F.cross_entropy(student_logits, labels)

        if not self.hard:
            distill_loss = F.kl_div(
                F.log_softmax(distill_logits / T, dim = -1),
                F.softmax(teacher_logits / T, dim = -1).detach(),
            reduction = 'batchmean')
            distill_loss *= T ** 2

        else:
            teacher_labels = teacher_logits.argmax(dim = -1)
            distill_loss = F.cross_entropy(distill_logits, teacher_labels)  # the teacher's hard labels supervise the distillation head

        return loss * (1 - alpha) + distill_loss * alpha
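
A minimal usage sketch for the DistillWrapper above (the teacher here is a hypothetical torchvision ResNet; DistillableViT is the distillable ViT class from the same distillation module these excerpts follow):

import torch
from torchvision.models import resnet50

teacher = resnet50(pretrained = True)            # assumption: any reasonably strong CNN classifier works as teacher
student = DistillableViT(                        # a ViT student exposing .dim and .num_classes
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

distiller = DistillWrapper(teacher = teacher, student = student,
                           temperature = 3., alpha = 0.5, hard = False)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)                    # (1 - alpha) * CE + alpha * distillation loss
loss.backward()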

Changing the Network Structure

CaiT

1. Modifying the Residual Connections (LayerScale)

Structure Comparison
  • The original residual formulation

The original residual equations:
$x_l'=x_l+\mathrm{SA}(x_l)$
$x_{l+1}=x_l'+\mathrm{FFN}(x_l')$

  • The new residual formulation

The new residual equations:
$x_l'=x_l+\mathrm{diag}(\lambda_{l,1},\dots,\lambda_{l,d})\,\mathrm{SA}(x_l)$
$x_{l+1}=x_l'+\mathrm{diag}(\lambda'_{l,1},\dots,\lambda'_{l,d})\,\mathrm{FFN}(x_l')$
Here the weights $\lambda$ are learnable parameters initialized to very small values.

In the Code
  1. A LayerScale module is added
class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:  # epsilon detailed in section 2 of paper
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale
  2. The residual blocks are wrapped with LayerScale
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))
            # previously there was no LayerScale wrapper, only
            # PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            # PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Why It Helps

By introducing a very small learnable parameter $\lambda$, the change each layer applies to the incoming representation is kept small, which mitigates the overfitting that appears as the number of layers grows.

2. The Class-Attention Structure

Structure Comparison

In the original ViT, the class token enters the network together with the patch features from the very beginning and is trained alongside them.

In ViT, Q, K, and V for the attention computation are obtained as:
$Q=W_qz+b_q$
$K=W_kz+b_k$
$V=W_vz+b_v$
$z=[x_{class},x_{patches}]$

In CaiT, feature learning and class learning are separated. The patch features are first processed on their own, and only towards the end (12 self-attention layers followed by 2 class-attention layers in the demo) is the class token introduced; in this stage the patch features only take part in the attention computation and do not pass through the feed-forward network.

In the self-attention stage, Q, K, and V are obtained as:
$Q=W_qx_{patches}+b_q$
$K=W_kx_{patches}+b_k$
$V=W_vx_{patches}+b_v$

In the class-attention stage, Q, K, and V are obtained as:
$Q=W_qx_{class}+b_q$
$K=W_kz+b_k$
$V=W_vz+b_v$
$z=[x_{class},x_{patches}]$

In the Code
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))
    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            # when context is None, x holds the patch tokens (self-attention stage)
            # when context is not None, context holds the patch tokens and x is the class token (class-attention stage)
            x = ff(x) + x
        return x

class CaiT(nn.Module):
    def __init__(
        self, *, image_size, patch_size, num_classes, dim, depth, cls_depth, heads, mlp_dim, dim_head = 64, dropout = 0., emb_dropout = 0., layer_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        # two Transformers: one over the patch tokens, one for the class token

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        x = self.patch_transformer(x)
        # run patch_transformer first; the cls_token has not been added yet

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = self.cls_transformer(cls_tokens, context = x)
        # then run cls_transformer with the cls_token added; its output is the updated cls_token
        return self.mlp_head(x[:, 0])
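
A minimal usage sketch for the CaiT class above (hypothetical hyperparameters; the helpers PreNorm, Attention, FeedForward and dropout_layers are assumed to come from the same codebase as these excerpts):

import torch

model = CaiT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 12, cls_depth = 2,       # 12 self-attention layers, then 2 class-attention layers
    heads = 16, mlp_dim = 2048,
    dropout = 0.1, emb_dropout = 0.1, layer_dropout = 0.05)

img = torch.randn(1, 3, 256, 256)
preds = model(img)                               # (1, 1000)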

Tokens-to-Token ViT (Patch Embedding)

Network Structure

1. Overall Architecture

The network has two main parts: a patch-embedding part without a class token and a Transformer part with a class token. The part with the class token has the same structure as ViT; the major change is the newly added class-token-free front end, which the authors call the "Tokens-to-Token module".

2. Tokens-to-Token module

The whole process can be broken into the following steps:

  1. The $C$-channel image is first split into $(h\cdot w)$ patches of size $p_1\times p_2$, and each patch is flattened, giving $(h\cdot w)$ tokens of size $(p_1\cdot p_2\cdot C)$.
  2. These $(h\cdot w)$ tokens of size $n$ are reorganized into $n$ feature maps of size $h\times w$.
  3. An Unfold operation with a kernel of size $[k_1,k_2]$ is applied to these $n$ maps of size $h\times w$; each sliding window is flattened into a column, giving an output with $n\cdot k_1\cdot k_2$ rows and $l$ columns, where

$l=\lfloor\frac{h+2p-k_1}{s}+1\rfloor\times\lfloor\frac{w+2p-k_2}{s}+1\rfloor$
$p$ is the padding size
$s$ is the stride size

  4. Because we want to preserve the relationships between the different regions, this output is transposed, giving $l$ tokens of size $n\cdot k_1\cdot k_2$.
  5. Each token is fed into a Transformer, so that information can be shared across regions.
  6. Repeat from step 2.

The Unfold operation is essentially just another kind of reshape: it involves no learning or computation, it simply concatenates the tokens inside each sliding window into a single token.
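
As a quick sanity check, the short sketch below (illustrative sizes, not taken from the paper) shows that nn.Unfold only rearranges data:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 14, 14)                   # n = 64 channels, h = w = 14
unfold = nn.Unfold(kernel_size = 3, stride = 2, padding = 1)
out = unfold(x)
print(out.shape)                                 # torch.Size([1, 576, 49]): 64*3*3 rows, l = 7*7 columns
tokens = out.transpose(1, 2)                     # (1, 49, 576): l tokens of size n*k1*k2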

Code

def conv_output_size(image_size, kernel_size, stride, padding):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

# classes

class RearrangeImage(nn.Module):
    def forward(self, x):
        return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))

# main class

class T2TViT(nn.Module):
    def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        layers = []
        layer_dim = channels
        output_image_size = image_size

        for i, (kernel_size, stride) in enumerate(t2t_layers):
            layer_dim *= kernel_size ** 2
            is_first = i == 0
            output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
            ])
            # each layer performs 4 steps: rearrange into an image, Unfold, transpose, Transformer

        layers.append(nn.Linear(layer_dim, dim))
        self.to_patch_embedding = nn.Sequential(*layers)
        # the patch embedding is now much more involved than the single Rearrange used in ViT

        self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        if not exists(transformer):
            assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        else:
            self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        # compared with ViT, only the patch embedding has changed
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
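
A minimal usage sketch for the T2TViT class above (hypothetical hyperparameters; Transformer and the other helpers are assumed to come from the same codebase):

import torch

model = T2TViT(
    image_size = 224, num_classes = 1000, dim = 512,
    depth = 5, heads = 8, mlp_dim = 512,
    t2t_layers = ((7, 4), (3, 2), (3, 2)))       # (kernel_size, stride) of each Tokens-to-Token stage

img = torch.randn(1, 3, 224, 224)
preds = model(img)                               # (1, 1000)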

Cross-ViT (Multi-Scale Features)

Overall Structure

  1. The input image is split with two patch sizes, giving an S-branch (small patch size) and an L-branch (large patch size).
  2. Each branch is first processed by its own Transformer.
  3. The results of the two branches are then combined with Cross-Attention.
  4. The predictions of the two patch sizes are summed to produce the final output (in the code, the CLS token of each branch goes through its own mlp_head and the two logits are added, as the forward method below shows).
def forward(self, img):
        sm_tokens = self.sm_image_embedder(img)
        lg_tokens = self.lg_image_embedder(img)

        sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

        sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))

        sm_logits = self.sm_mlp_head(sm_cls)
        lg_logits = self.lg_mlp_head(lg_cls)

        return sm_logits + lg_logits

The Cross-Attention Method

  1. Take the CLS token of the L-branch, $x^l_{cls}$, and the patch tokens of the S-branch, $x^s_{patch}$; combining them gives the data in the dashed box at the bottom of the figure, $x'^l$:
    $x'^l=[f^l(x^l_{cls})\,\|\,x^s_{patch}]$
  2. Compute Q from $x'^l_{cls}$, and K and V from $x'^l$:
    $Q=x'^l_{cls}W_q\qquad K=x'^lW_k\qquad V=x'^lW_v$
  3. Compute the Cross-Attention value $\mathrm{CA}(x'^l)$:
    $\mathrm{CA}(x'^l)=\mathrm{Softmax}(QK^T/\sqrt{C/h})V$
  4. Compute the new CLS token $y^l_{cls}$:
    $y^l_{cls}=g^l(f^l(x^l_{cls})+\mathrm{CA}(x'^l))$
  5. Concatenate to obtain the new L-branch tokens:
    $z^l=[y^l_{cls}\,\|\,x^l_{patch}]$

$f^l(\cdot)$ and $g^l(\cdot)$ are linear projections applied on the L-branch side.

Code (Cross-Attention Part)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        b, n, _, h = *x.shape, self.heads
        context = default(context, x)

        if kv_include_self:
            context = torch.cat((x, context), dim = 1) # cross attention requires CLS token includes itself as key / value

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# projecting CLS tokens, in the case that small and large patch tokens have different dimensions

class ProjectInOut(nn.Module):
    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        need_projection = dim_in != dim_out
        self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
        self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)
        x = self.fn(x, *args, **kwargs)
        x = self.project_out(x)
        return x

# cross attention transformer

class CrossTransformer(nn.Module):
    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
            ]))

    def forward(self, sm_tokens, lg_tokens):
        (sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))

        for sm_attend_lg, lg_attend_sm in self.layers:
            sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
            lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls

        sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim = 1)
        lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim = 1)
        return sm_tokens, lg_tokens
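
A minimal shape sketch for the CrossTransformer above (hypothetical dimensions; PreNorm and the other helpers are assumed to come from the same codebase):

import torch

cross = CrossTransformer(sm_dim = 192, lg_dim = 384, depth = 2,
                         heads = 8, dim_head = 64, dropout = 0.)

sm_tokens = torch.randn(2, 1 + 256, 192)         # CLS token + small-patch tokens
lg_tokens = torch.randn(2, 1 + 64, 384)          # CLS token + large-patch tokens

sm_tokens, lg_tokens = cross(sm_tokens, lg_tokens)
print(sm_tokens.shape, lg_tokens.shape)          # shapes unchanged; only the CLS tokens are updated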

PiT (Adding Pooling Layers)

Overall Structure

  1. First perform the patch embedding to obtain the initial tokens.
  2. Run several Transformer layers.
  3. Enter a pooling layer: the patch tokens are downsampled, and the cls token goes through a fully connected layer so that its dimension matches that of the patch tokens.
  4. Repeat step 2, then step 3, then step 2 again.
  5. Produce the output from the cls token.

Pooling Layer

  1. Since pooling is an operation on 2D tensors, the patch tokens are first reshaped back into a 2D feature map.
  2. A depthwise-separable convolution is used for downsampling, reducing the spatial size while increasing the channel depth (I suspect this is a slip by the paper's authors).

Depthwise-separable convolution (the Xception structure)
Assume the input has shape $w\times h\times d$.

  1. Apply $d$ depthwise $3\times 3$ kernels, one per channel, with stride 2, giving an output of shape $\frac{w}{2}\times\frac{h}{2}\times d$.
  2. Apply $2d$ pointwise $1\times 1\times d$ kernels, each computing a weighted sum over the channels, giving an output of shape $\frac{w}{2}\times\frac{h}{2}\times 2d$.

Depthwise-separable convolution (as implemented in the code)
Assume the input has shape $w\times h\times d$.

  1. Apply $2d$ depthwise $3\times 3$ kernels, two per input channel, with stride 2, giving an output of shape $\frac{w}{2}\times\frac{h}{2}\times 2d$.
  2. Apply $2d$ pointwise $1\times 1\times 2d$ kernels, each computing a weighted sum over the channels, keeping the shape $\frac{w}{2}\times\frac{h}{2}\times 2d$.

  3. The patch tokens are then flattened back into a 1D sequence for the following Transformer layers.
  4. The cls token simply goes through a fully connected layer to increase its depth.

Code

# depthwise convolution, for pooling

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
            # the Xception-style structure would instead be:
            # nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            # nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

# pooling layer

class Pool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
        self.cls_ff = nn.Linear(dim, dim * 2)

    def forward(self, x):
        cls_token, tokens = x[:, :1], x[:, 1:]

        cls_token = self.cls_ff(cls_token)

        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1])))
        tokens = self.downsample(tokens)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')

        return torch.cat((cls_token, tokens), dim = 1)

# main class

class PiT(nn.Module):
    def __init__(
        self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, dim_head = 64, dropout = 0., emb_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
        heads = cast_tuple(heads, len(depth))

        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
            Rearrange('b c n -> b n c'),
            nn.Linear(patch_dim, dim)
        )

        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
        num_patches = output_size ** 2

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        layers = []

        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
            not_last = ind < (len(depth) - 1)
            
            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))

            if not_last:
                layers.append(Pool(dim))
                dim *= 2

        self.layers = nn.Sequential(*layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.layers(x)

        return self.mlp_head(x[:, 0])
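
A minimal usage sketch for the PiT class above (hypothetical hyperparameters; cast_tuple, conv_output_size and Transformer are assumed to come from the same codebase; each Pool layer doubles dim and roughly quarters the number of patch tokens):

import torch

model = PiT(
    image_size = 224, patch_size = 14, num_classes = 1000, dim = 256,
    depth = (3, 3, 3),              # number of Transformer blocks in each of the three stages
    heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1)

img = torch.randn(1, 3, 224, 224)
preds = model(img)                  # (1, 1000); the feature dim is 256 * 4 = 1024 after the two Pool layers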

LeViT

Overall Structure

The architecture is clearly divided into three main parts:

  • Convolutional part: it extracts features from the input image and reduces its spatial size (the paper verifies that putting a small convolutional network in front of the Transformer improves accuracy).
  • Transformer part: it is split into 3 stages. When no downsampling is performed there is a residual connection; when downsampling is performed the shape changes, so there is no residual connection. This part is analysed in detail below.
  • Classifier part: LeViT does not use a cls_token; the final feature map is simply average-pooled. If distillation is used, the pooled result is fed into both a classification head and a distillation head (implemented in the code as two nn.Linear layers); otherwise it goes into a single classifier.
  • Tip: every convolution is followed by a BatchNorm layer, and the BatchNorm layers feeding into residual connections are initialized with zero weight.

Attention Block

In the figure, the left block is the normal attention and the right one is the Shrink Attention used for downsampling.
$D$ is the dimension of Q and K; it is set to 32 in the code.
Without downsampling, V has dimension $2D$; with downsampling, V has dimension $4D$.
$N$ is the number of heads: $N=C/2D$ without downsampling and $N=C/D$ with downsampling.

Without downsampling (the normal case)

Assume the input has shape $C\times H\times W$.

  1. Compute $Q$, $K$, $V$
    Pointwise (1×1) convolutions are applied to the input to obtain Q, K, and V:
    $Q\mathrm{.shape}=ND\times H\times W$
    $K\mathrm{.shape}=ND\times H\times W$
    $V\mathrm{.shape}=N2D\times H\times W$
    They are then reshaped:
    $Q\mathrm{.shape}=N\times HW\times D$
    $K\mathrm{.shape}=N\times HW\times D$
    $V\mathrm{.shape}=N\times HW\times 2D$
  2. Compute the attention bias $B$
    The attention bias replaces ViT's pos_embedding, so that positional information is injected at every layer (the paper cites prior work showing that adding positional information improves Transformer accuracy).
    self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
    # fmap_size = 14
    q_range = torch.arange(0, fmap_size, step = 1)
    k_range = torch.arange(fmap_size)
    
    q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
    k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
    # q_pos: torch.Size([14, 14, 2])
    # k_pos: torch.Size([14, 14, 2])
    
    q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
    # rearranged q_pos: torch.Size([196, 2])
    # rearranged k_pos: torch.Size([196, 2])
    
    rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
     # rel_pos: torch.Size([196, 196, 2])
     
    x_rel, y_rel = rel_pos.unbind(dim = -1)
    # x_rel: torch.Size([196, 196])
    # y_rel: torch.Size([196, 196])
    
    pos_indices = (x_rel * fmap_size) + y_rel
    # pos_indices: torch.Size([196, 196])
    
    pos_indices has shape $HW\times HW$; it is then expanded to $N\times HW\times HW$.
  3. Compute the attention weights $A$
    $A=\mathrm{Softmax}((QK^T+B)/\sqrt{D})$
    $A\mathrm{.shape}=N\times HW\times HW$
  4. Compute the attention value
    $\mathrm{Attention}(Q,K,V)=AV$
    $\mathrm{Attention}(Q,K,V)\mathrm{.shape}=N\times HW\times 2D$
  5. Post-process the output
    It is first reshaped back to $N2D\times H\times W$,
    then passed through a Hardswish activation (the code excerpt below uses GELU instead),
    and finally a pointwise convolution maps it from $N2D\times H\times W$ back to $C\times H\times W$.
    $\mathrm{Hardswish}(x)=\begin{cases}0 & x\leq -3\\ x & x\geq +3\\ x\cdot(x+3)/6 & \text{otherwise}\end{cases}$
With downsampling (Shrink Attention)

Assume the input has shape $C\times H\times W$.

  1. Compute $Q$, $K$, $V$
    Pointwise convolutions are applied to the input to obtain Q, K, and V; for Q the stride is set to 2:
    $Q\mathrm{.shape}=ND\times \frac{H}{2}\times \frac{W}{2}$
    $K\mathrm{.shape}=ND\times H\times W$
    $V\mathrm{.shape}=N4D\times H\times W$
    They are then reshaped:
    $Q\mathrm{.shape}=N\times \frac{HW}{4}\times D$
    $K\mathrm{.shape}=N\times HW\times D$
    $V\mathrm{.shape}=N\times HW\times 4D$
  2. Compute the attention bias $B$
    The attention bias replaces ViT's pos_embedding, so that positional information is injected at every layer (the paper cites prior work showing that adding positional information improves Transformer accuracy).
    self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
    # fmap_size = 14
    q_range = torch.arange(0, fmap_size, step = 2)
    k_range = torch.arange(fmap_size)
    
    q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
    k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
    # q_pos: torch.Size([7, 7, 2])
    # k_pos: torch.Size([14, 14, 2])
    
    q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
    # rearranged q_pos: torch.Size([49, 2])
    # rearranged k_pos: torch.Size([196, 2])
    
    rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
     # rel_pos: torch.Size([49, 196, 2])
     
    x_rel, y_rel = rel_pos.unbind(dim = -1)
    # x_rel: torch.Size([49, 196])
    # y_rel: torch.Size([49, 196])
    
    pos_indices = (x_rel * fmap_size) + y_rel
    # pos_indices: torch.Size([49, 196])
    
    pos_indices has shape $\frac{HW}{4}\times HW$; it is then expanded to $N\times \frac{HW}{4}\times HW$.
  3. Compute the attention weights $A$
    $A=\mathrm{Softmax}((QK^T+B)/\sqrt{D})$
    $A\mathrm{.shape}=N\times \frac{HW}{4}\times HW$
  4. Compute the attention value
    $\mathrm{Attention}(Q,K,V)=AV$
    $\mathrm{Attention}(Q,K,V)\mathrm{.shape}=N\times \frac{HW}{4}\times 4D$
  5. Post-process the output
    It is first reshaped to $N4D\times \frac{H}{2}\times \frac{W}{2}$,
    then passed through a Hardswish activation,
    and finally a pointwise convolution maps it from $N4D\times \frac{H}{2}\times \frac{W}{2}$ to $C'\times \frac{H}{2}\times \frac{W}{2}$.
    $\mathrm{Hardswish}(x)=\begin{cases}0 & x\leq -3\\ x & x\geq +3\\ x\cdot(x+3)/6 & \text{otherwise}\end{cases}$
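
These shapes can be checked directly with the Attention module from the code listing further below (hypothetical sizes; the default/exists helpers are assumed to come from the same codebase):

import torch

attn = Attention(dim = 256, fmap_size = 14, heads = 4, dim_key = 32, dim_value = 64)
x = torch.randn(1, 256, 14, 14)
print(attn(x).shape)                             # torch.Size([1, 256, 14, 14]): N = C/2D = 4, shape preserved

down = Attention(dim = 256, fmap_size = 14, heads = 8, dim_key = 32, dim_value = 64,
                 dim_out = 384, downsample = True)
print(down(x).shape)                             # torch.Size([1, 384, 7, 7]): N = C/D = 8, spatial size halved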

MLP

The MLP is simply a stack of two pointwise convolutions:

class FeedForward(nn.Module):
   def __init__(self, dim, mult, dropout = 0.):
       super().__init__()
       self.net = nn.Sequential(
           nn.Conv2d(dim, dim * mult, 1),
           nn.GELU(),
           nn.Dropout(dropout),
           nn.Conv2d(dim * mult, dim, 1),
           nn.Dropout(dropout)
       )
   def forward(self, x):
       return self.net(x)

Code

class Attention(nn.Module):
    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key *  heads
        inner_dim_value = dim_value *  heads
        dim_out = default(dim_out, dim)

        self.heads = heads
        self.scale = dim_key ** -0.5

        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)

        out_batch_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(out_batch_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            out_batch_norm,
            nn.Dropout(dropout)
        )

        # positional bias

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)

        q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)

    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        return fmap + (bias / self.scale)

    def forward(self, x):
        b, n, *_, h = *x.shape, self.heads

        q = self.to_q(x)
        y = q.shape[2]

        qkv = (q, self.to_k(x), self.to_v(x))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.layers = nn.ModuleList([])
        self.attn_residual = (not downsample) and dim == dim_out

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            attn_res = (x if self.attn_residual else 0)
            x = attn(x) + attn_res
            x = ff(x) + x
        return x

class LeViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'

        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        fmap_size = image_size // (2 ** 4)
        layers = []

        for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                next_dim = dims[ind + 1]
                layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.conv_embedding(img)

        x = self.backbone(x)        

        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)

        if exists(distill):
            return out, distill

        return out
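
A minimal usage sketch for the LeViT class above (hypothetical hyperparameters; cast_tuple, default, exists and always are assumed to come from the same codebase):

import torch

model = LeViT(
    image_size = 224, num_classes = 1000, stages = 3,
    dim = (256, 384, 512),          # per-stage embedding dims
    depth = 4, heads = (4, 6, 8),
    mlp_mult = 2, dropout = 0.1)

img = torch.randn(1, 3, 224, 224)
preds = model(img)                  # (1, 1000)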

CvT (Convolution + Transformer)

Overall Architecture

Note: in the figure in the paper, a cls_token is added in stage 3 and the final classification is performed on that cls_token. The reference code does not do this; instead, the output of an adaptive average pooling layer (nn.AdaptiveAvgPool2d) is used as the input to the classifier.

The overall pipeline is divided into 3 stages.
Each stage contains 2 modules, called the Convolutional Token Embedding and the Convolutional Transformer Blocks.

Module Details

Convolutional Token Embedding

This module is just a convolution layer whose purpose is to extract features and reduce the spatial size of the image.
The convolution is followed by a LayerNorm.

Convolutional Transformer Blocks

This module differs from the Transformer part of ViT in how Q, K, and V are generated.

In ViT, the 2D feature map is first flattened into a sequence, and Q, K, and V are obtained with linear projections:
$x_i^{q/k/v}=\mathrm{Linear}(x_i)$

In CvT, the 2D feature map is used directly: a convolution produces 2D representations of Q, K, and V, which are then flattened into sequences:
$x_i^{q/k/v}=\mathrm{Flatten}(\mathrm{Conv2d}(\mathrm{Reshape2D}(x_i),s))$
After the attention computation, the sequence is reshaped back into a 2D feature map for the following operations.
Reading the source code, one notices that no positional information (ViT's pos_embedding) is added when computing the attention values. The paper explains why:
the convolutional projection already implicitly preserves the relationship between each element and its neighbours, so the pos_embedding becomes much less important.

Code

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.BatchNorm2d(dim_in),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

This DepthWiseConv2d module follows the Xception structure; compare it with the DepthWiseConv2d used in PiT.

class Attention(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
		
        # convolutional projections for Q and K/V
        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        b, n, _, y, h = *shape, self.heads
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
class CvT(nn.Module):
    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64, s1_emb_kernel = 7,  s1_emb_stride = 4, s1_proj_kernel = 3, s1_kv_proj_stride = 2, s1_heads = 1, s1_depth = 1, s1_mlp_mult = 4,
        s2_emb_dim = 192, s2_emb_kernel = 3, s2_emb_stride = 2, s2_proj_kernel = 3, s2_kv_proj_stride = 2, s2_heads = 3, s2_depth = 2, s2_mlp_mult = 4,
        s3_emb_dim = 384, s3_emb_kernel = 3, s3_emb_stride = 2, s3_proj_kernel = 3, s3_kv_proj_stride = 2, s3_heads = 6, s3_depth = 10, s3_mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        kwargs = dict(locals())

        dim = 3
        layers = []

        # the modules inside each stage
        for prefix in ('s1', 's2', 's3'):
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)

            layers.append(nn.Sequential(
                nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
                LayerNorm(config['emb_dim']),
                Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
            ))

            dim = config['emb_dim']

        self.layers = nn.Sequential(
            *layers,
            # final adaptive average pooling instead of a cls_token for classification
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        return self.layers(x)
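
A minimal usage sketch for the CvT class above, using the default per-stage hyperparameters (LayerNorm, Transformer and group_by_key_prefix_and_remove_prefix are assumed to come from the same codebase):

import torch

model = CvT(num_classes = 1000)

img = torch.randn(1, 3, 224, 224)
preds = model(img)                  # (1, 1000)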
