DeepViT, DeiT, CaiT, T2T, Cross-ViT, PiT, LeViT, CvT: Algorithm Notes

Changing a function without changing the network structure


DeepViT (changing the attention value function)

Findings

  1. In the original ViT, adding more transformer layers does not keep improving results and can even hurt them (a 32-layer model performs worse than a 24-layer one)
  2. As the depth of a ViT increases, the attention maps become increasingly similar after a certain number of layers
  3. This phenomenon is called attention collapse (a rough way to quantify it is sketched below)
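As a rough, hypothetical illustration of how attention collapse can be observed, the sketch below computes the cosine similarity between the attention maps of adjacent layers; `cross_layer_attention_similarity` and the random maps are made up for illustration, and this is only a simplified proxy for the cross-layer similarity metric used in the paper.

import torch
import torch.nn.functional as F

def cross_layer_attention_similarity(attn_maps):
    # attn_maps: list of per-layer attention tensors of shape (batch, heads, n, n),
    # e.g. collected with forward hooks. Returns the mean cosine similarity
    # between the attention maps of adjacent layers; values close to 1 indicate
    # that the maps have stopped changing (attention collapse).
    sims = []
    for a, b in zip(attn_maps[:-1], attn_maps[1:]):
        sim = F.cosine_similarity(a.flatten(1), b.flatten(1), dim = -1)
        sims.append(sim.mean())
    return torch.stack(sims)

# illustration only: 12 layers of random "attention maps"
maps = [torch.rand(2, 8, 197, 197).softmax(dim = -1) for _ in range(12)]
print(cross_layer_attention_similarity(maps))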

Solution

When computing the attention values, exchange information among the attention weights of the different heads (implemented with a learnable matrix).
The original attention computation:
$$\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}\!\left(\frac{QK^T}{\sqrt d}\right)V$$
The new Re-Attention computation:
$$\mathrm{Re\text{-}Attention}(Q,K,V)=\mathrm{Norm}\!\left(\theta^T\,\mathrm{Softmax}\!\left(\frac{QK^T}{\sqrt d}\right)\right)V$$

Code

  • Re-Attention class
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class ReAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
        # learnable head-mixing (transformation) matrix

        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )
        # newly added normalization layer

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # Calculate attention weight

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        # re-attention

        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)
        # compute the new attention weights

        # aggregate and out

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out
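A quick sanity check of the module above (the sizes are arbitrary), assuming the imports added at the top of the snippet:

x = torch.randn(1, 197, 512)                      # (batch, tokens, dim): 196 patches + CLS token
re_attn = ReAttention(dim = 512, heads = 8, dim_head = 64)
out = re_attn(x)
print(out.shape)                                  # torch.Size([1, 197, 512])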

Purpose and advantages

  • The purpose is to replace dropout-style remedies, which are too crude a trade-off
  • The advantage is that the implementation is simple

DeiT (changing the loss)

How it is applied

  1. Take a well-performing model as the teacher: experiments show that a CNN teacher works better
  2. Use a ViT model as the student
  3. Append a distillation token to the end of the input, as the basis for computing the distillation output
  4. Use the resulting distillation term as part of the loss

Innovations

Two distillation schemes (two formulas for the distillation loss) are proposed; a minimal code mapping of both is sketched right after this list:

  • Soft distillation
    $$L_{global} = (1-\lambda)\,L_{CE}(\mathrm{Softmax}(Z_s),\,y) + \lambda\tau^2\,\mathrm{KL}\big(\mathrm{Softmax}(Z_s/\tau),\,\mathrm{Softmax}(Z_t/\tau)\big)$$

$Z_t$ is the output of the teacher model,
$Z_s$ is the output of the student model,
$y$ is the ground-truth label,
$L_{CE}$ is the cross-entropy, $\mathrm{KL}$ is the Kullback-Leibler divergence,
$\tau$ is the distillation temperature.

  • Hard-label distillation
    $$L_{global}^{\mathrm{hardDistill}} = \tfrac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),\,y) + \tfrac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),\,\mathrm{argmax}(Z_t))$$
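As a minimal, standalone sketch of how the two formulas map to PyTorch (made-up logits, no distillation token yet; the full implementation used in practice follows below). The values of tau and lam are assumptions:

import torch
import torch.nn.functional as F

tau, lam = 3.0, 0.5                               # temperature and mixing weight (assumed values)
z_s = torch.randn(4, 10)                          # student logits Z_s
z_t = torch.randn(4, 10)                          # teacher logits Z_t
y = torch.randint(0, 10, (4,))                    # ground-truth labels

# soft distillation: CE with the true labels + temperature-scaled KL to the teacher
soft_loss = (1 - lam) * F.cross_entropy(z_s, y) + \
            lam * tau ** 2 * F.kl_div(F.log_softmax(z_s / tau, dim = -1),
                                      F.softmax(z_t / tau, dim = -1),
                                      reduction = 'batchmean')

# hard-label distillation: CE with the true labels + CE with the teacher's argmax
hard_loss = 0.5 * F.cross_entropy(z_s, y) + \
            0.5 * F.cross_entropy(z_s, z_t.argmax(dim = -1))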

In code

  • Adding the distillation token to the input
import torch
from einops import repeat

def exists(val):
    return val is not None

# classes

class DistillMixin:
    def forward(self, img, distill_token = None):
        distilling = exists(distill_token)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        if distilling:
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
            x = torch.cat((x, distill_tokens), dim = 1)
            # append the distillation token at the end of the sequence

        x = self._attend(x)

        if distilling:
            x, distill_tokens = x[:, :-1], x[:, -1]

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        out = self.mlp_head(x)

        if distilling:
            return out, distill_tokens

        return out
  • Computing the loss
import torch
from torch import nn
import torch.nn.functional as F

class DistillWrapper(nn.Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5,
        hard = False
    ):
        super().__init__()
        assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha
        self.hard = hard

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        b, *_ = img.shape
        alpha = alpha if exists(alpha) else self.alpha
        T = temperature if exists(temperature) else self.temperature

        with torch.no_grad():
            teacher_logits = self.teacher(img)

        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        loss = F.cross_entropy(student_logits, labels)

        if not self.hard:
            distill_loss = F.kl_div(
                F.log_softmax(distill_logits / T, dim = -1),
                F.softmax(teacher_logits / T, dim = -1).detach(),
            reduction = 'batchmean')
            distill_loss *= T ** 2

        else:
            teacher_labels = teacher_logits.argmax(dim = -1)
            # hard-label distillation: use the distillation head output, as in the soft branch
            distill_loss = F.cross_entropy(distill_logits, teacher_labels)

        return loss * (1 - alpha) + distill_loss * alpha
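A hypothetical end-to-end usage sketch of the wrapper above; the ResNet-50 teacher, the model sizes and the DistillableViT constructor arguments are assumptions following the usual vit-pytorch interface:

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)             # CNN teacher, as the paper recommends

student = DistillableViT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 8, mlp_dim = 2048,
    dropout = 0.1, emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student, teacher = teacher,
    temperature = 3, alpha = 0.5, hard = False    # soft distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)                     # (1 - alpha) * CE + alpha * distillation loss
loss.backward()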

Changing the network structure

CaiT

1. Changing the residual branches (LayerScale + Class-Attention)

Structure comparison
  • The original residual branch

Original residual formulation:
$$x_l' = x_l + \mathrm{SA}(x_l)$$
$$x_{l+1} = x_l' + \mathrm{FFN}(x_l')$$

  • The new residual branch

New residual formulation:
$$x_l' = x_l + \mathrm{diag}(\lambda_{l,1},\dots,\lambda_{l,d})\,\mathrm{SA}(x_l)$$
$$x_{l+1} = x_l' + \mathrm{diag}(\lambda'_{l,1},\dots,\lambda'_{l,d})\,\mathrm{FFN}(x_l')$$
Here the weights $\lambda$ are learnable parameters initialized to very small values.

In code
  1. A LayerScale module is added
import torch
from torch import nn

class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:  # epsilon detailed in section 2 of paper
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale
  2. The residual branches in the Transformer are modified accordingly
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))
            # without LayerScale these two entries were simply
            # PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            # PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Significance

By introducing a very small learnable parameter $\lambda$, the change each layer applies to the incoming representation is kept small, which eases the degradation that appears as the network gets deeper (a small sketch of this effect follows below).
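A minimal sketch of this effect using the LayerScale class above (the wrapped nn.Linear and all sizes are made up): at initialization the scaled residual branch contributes almost nothing, so each block starts out close to the identity.

import torch
from torch import nn

# wrap a plain linear layer; depth = 30 selects init_eps = 1e-6 in LayerScale above
branch = LayerScale(dim = 64, fn = nn.Linear(64, 64), depth = 30)

x = torch.randn(2, 16, 64)
update = branch(x)                                # diag(lambda) * f(x): the scaled branch output
out = x + update                                  # the residual update x_{l+1} = x_l + diag(lambda) f(x_l)

print((update.norm() / x.norm()).item())          # tiny ratio at init, so out ≈ x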
