Changing Functions Without Changing the Network Structure
DeepViT (changing the attention-value function)
Findings
- In the original ViT, adding more transformer layers does not improve the results and can even hurt them (a 32-layer model performs worse than a 24-layer one)
- As the depth of the ViT model increases, the attention maps become increasingly similar beyond a certain number of layers
- This phenomenon is called attention collapse
Solution
When computing the values, the attention weights of the different heads are mixed with one another (implemented with a learnable matrix)
Original attention computation:
$$\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}\left(\frac{QK^T}{\sqrt d}\right)V$$
New Re-Attention computation:
$$\mathrm{Re\text{-}Attention}(Q,K,V)=\mathrm{Norm}\left(\theta^T\,\mathrm{Softmax}\left(\frac{QK^T}{\sqrt d}\right)\right)V$$
Code
- The Re-Attention class
```python
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange

class ReAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # learnable head-mixing matrix theta (heads x heads)
        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

        # newly added Norm layer (LayerNorm applied across the head dimension)
        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # compute the usual attention weights
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim = -1)

        # re-attention: mix the attention maps across heads with the learnable matrix, then normalize
        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)

        # aggregate the values and project out
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
```
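A quick shape check of the class defined above, as a minimal sketch (the batch size, token count, and dimensions are arbitrary illustration values):

```python
# minimal usage sketch for the ReAttention module defined above
# (dim / heads / batch values are arbitrary illustration choices)
import torch

reattn = ReAttention(dim = 256, heads = 8, dim_head = 64)
tokens = torch.randn(2, 65, 256)   # (batch, num_tokens, dim), e.g. 64 patches + 1 cls token
out = reattn(tokens)
print(out.shape)                   # torch.Size([2, 65, 256])
```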
Purpose and advantages
- The goal is to replace overly simplistic remedies such as dropout
- Its advantage is that it is very simple to implement
DeiT (changing the loss)
How it is applied
- Use a well-performing model as the teacher; experiments show that a CNN teacher works better
- Use a ViT model as the student
- Append a distillation token to the end of the input sequence as the basis for computing the distillation
- Use the resulting distillation term as part of the loss

Innovations
Two distillation schemes (two formulas for the distillation loss) are proposed; a minimal sketch of both losses follows the formulas:
- Soft distillation
$$L_{global} = (1-\lambda)L_{CE}(\mathrm{Softmax}(Z_s),y)+\lambda\tau^2\,\mathrm{KL}(\mathrm{Softmax}(Z_s/\tau),\mathrm{Softmax}(Z_t/\tau))$$
$Z_t$ is the output (logits) of the teacher model
$Z_s$ is the output (logits) of the student model
$y$ is the ground-truth label
$L_{CE}$ is the cross-entropy loss, $\mathrm{KL}$ is the Kullback-Leibler divergence
$\tau$ is the distillation temperature
- Hard-label distillation
$$L_{global}^{\mathrm{hardDistill}} = \frac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),y)+\frac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),\mathrm{argmax}(Z_t))$$
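As a minimal sketch of the two formulas above (tensor shapes and the values of $\lambda$ and $\tau$ are arbitrary illustration choices), both losses can be written directly with standard PyTorch functions:

```python
# minimal sketch of the two distillation losses above
# (shapes, lambda and tau are arbitrary illustration values)
import torch
import torch.nn.functional as F

z_s = torch.randn(4, 10)           # student logits Z_s
z_t = torch.randn(4, 10)           # teacher logits Z_t
y   = torch.randint(0, 10, (4,))   # ground-truth labels
lam, tau = 0.5, 3.0

# soft distillation
soft = (1 - lam) * F.cross_entropy(z_s, y) \
     + lam * tau**2 * F.kl_div(F.log_softmax(z_s / tau, dim = -1),
                               F.softmax(z_t / tau, dim = -1),
                               reduction = 'batchmean')

# hard-label distillation
hard = 0.5 * F.cross_entropy(z_s, y) + 0.5 * F.cross_entropy(z_s, z_t.argmax(dim = -1))
```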
In code
- Appending the distillation token to the input
```python
import torch
from einops import repeat

def exists(val):
    return val is not None

# classes

# mixin that lets a ViT accept an extra distillation token in its forward pass
class DistillMixin:
    def forward(self, img, distill_token = None):
        distilling = exists(distill_token)

        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        if distilling:
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
            # append the distillation token at the end of the sequence
            x = torch.cat((x, distill_tokens), dim = 1)

        x = self._attend(x)

        if distilling:
            # split the distillation token back off after attention
            x, distill_tokens = x[:, :-1], x[:, -1]

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        out = self.mlp_head(x)

        if distilling:
            return out, distill_tokens

        return out
```
- Computing the loss
```python
import torch
from torch import nn
import torch.nn.functional as F

class DistillWrapper(nn.Module):
    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5,
        hard = False
    ):
        super().__init__()
        # DistillableViT / DistillableT2TViT / DistillableEfficientViT are the ViT
        # variants that use the DistillMixin above
        assert isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT)), 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha
        self.hard = hard

        # the learnable distillation token
        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        # head that maps the distillation token to class logits
        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        b, *_ = img.shape
        alpha = alpha if exists(alpha) else self.alpha
        T = temperature if exists(temperature) else self.temperature

        # the teacher is frozen during distillation
        with torch.no_grad():
            teacher_logits = self.teacher(img)

        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        # cross entropy of the student against the ground-truth labels
        loss = F.cross_entropy(student_logits, labels)

        if not self.hard:
            # soft distillation: KL divergence between temperature-scaled distributions
            distill_loss = F.kl_div(
                F.log_softmax(distill_logits / T, dim = -1),
                F.softmax(teacher_logits / T, dim = -1).detach(),
                reduction = 'batchmean')
            distill_loss *= T ** 2
        else:
            # hard-label distillation: cross entropy against the teacher's argmax labels
            teacher_labels = teacher_logits.argmax(dim = -1)
            distill_loss = F.cross_entropy(student_logits, teacher_labels)

        return loss * (1 - alpha) + distill_loss * alpha
```
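A minimal usage sketch of the wrapper above, assuming `DistillableViT` and `DistillWrapper` are importable from the vit-pytorch package; the ResNet teacher is untrained here and all hyper-parameters are only illustrative (in practice DeiT uses a well-trained CNN teacher):

```python
# minimal usage sketch (assumes the vit-pytorch package provides these classes;
# hyper-parameters and the untrained ResNet teacher are only for illustration)
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(num_classes = 1000)   # in practice: a well-trained CNN teacher

student = DistillableViT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 8, mlp_dim = 2048,
    dropout = 0.1, emb_dropout = 0.1
)

distiller = DistillWrapper(
    teacher = teacher,
    student = student,
    temperature = 3.,   # tau
    alpha = 0.5,        # lambda
    hard = False        # set True for hard-label distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()
```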
Changing the Network Structure
CaiT
1. Changing the residual blocks (LayerScale + Class-Attention)
Structure comparison
- The original residual blocks

Formulas of the original residual blocks:
$$x_l'=x_l+\mathrm{SA}(x_l) \\ x_{l+1}=x_l'+\mathrm{FFN}(x_l')$$
- The new residual blocks

Formulas of the new residual blocks:
$$x_l'=x_l+\mathrm{diag}(\lambda_{l,1},\dots,\lambda_{l,d})\,\mathrm{SA}(x_l) \\ x_{l+1}=x_l'+\mathrm{diag}(\lambda'_{l,1},\dots,\lambda'_{l,d})\,\mathrm{FFN}(x_l')$$
Here the weights $\lambda$ are learnable parameters initialized to very small values
In code
- A LayerScale module is added
```python
import torch
from torch import nn

class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        # epsilon initialization detailed in section 2 of the paper
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        # scale the wrapped branch's output by the learnable per-channel weights
        return self.fn(x, **kwargs) * self.scale
```
- The residual blocks are modified
```python
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))
            # previously there was no LayerScale wrapper, only
            # PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            # PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))

    def forward(self, x):
        # simplified forward (the full version also applies stochastic layer dropout)
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
```
Significance
By introducing a very small learnable parameter $\lambda$, the change that each layer applies to the incoming information is kept small, which alleviates the degradation that appears as the number of layers grows.
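A tiny numeric sketch of that effect (all values here are arbitrary): with $\lambda$ initialized to 1e-5, a residual block is almost an identity map at the start of training, so stacking many such blocks barely distorts the signal early on.

```python
# tiny sketch: with a small initial scale, a residual block is near-identity at init
# (shapes and values are arbitrary illustration choices)
import torch

x = torch.randn(1, 4, 8)             # (batch, tokens, dim)
branch_out = torch.randn(1, 4, 8)    # stand-in for SA(x) or FFN(x)
scale = torch.full((1, 1, 8), 1e-5)  # lambda initialized at 1e-5 (used for depths 19-24 above)

x_next = x + scale * branch_out
print((x_next - x).abs().max())      # on the order of 1e-5
```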