改变函数不变网络结构
DeepVit(改变Attention Value函数)
发现
- 在原来的Vit模型中,随着transformer层的增多,结果并没有得到优化,甚至会出现下降(32层的差于24层的)
- 随着ViT模型深度的增加,注意力图在一定深层数的计算之后会趋于相似
- 这个现象被称作attention collapse
解决方法
计算value时,将各个head的attention weight进行交换(用一个可学习的矩阵实现)
原Attention Value计算公式:
$$\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}\left(\frac{QK^T}{\sqrt d}\right)V$$
新的Re-Attention Value计算公式:
$$\mathrm{Re\text{-}Attention}(Q,K,V)=\mathrm{Norm}\left(\theta^T\left(\mathrm{Softmax}\left(\frac{QK^T}{\sqrt d}\right)\right)\right)V$$
其中$\theta$是一个可学习的、大小为 heads × heads 的矩阵,用来在各个head之间混合attention weight
代码
- Re-Attention类
class ReAttention(nn.Module):
    """Multi-head self-attention with DeepViT-style re-attention.

    After the usual softmax, the attention maps are mixed across heads with a
    learnable ``heads x heads`` matrix and then normalised over the head axis
    before being applied to the values.

    NOTE(review): the original snippet spelled this class ``Re-Attention``,
    which is not a valid Python identifier; it is named ``ReAttention`` here.

    Args:
        dim: size of the last axis of the input tokens.
        heads: number of attention heads.
        dim_head: per-head width of q, k and v.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        # One bias-free projection produces q, k and v side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # Learnable matrix that mixes attention maps between heads.
        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

        # Norm added by re-attention: the head axis is moved last so that
        # LayerNorm(heads) normalises across heads, then moved back.
        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply re-attention to ``x`` of layout ``(b, n, dim)``.

        Returns a tensor with the same ``(b, n, dim)`` layout.
        """
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # Standard scaled dot-product attention weights.
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim = -1)

        # Re-attention: mix the maps across heads, then normalise.
        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)

        # Aggregate the values and merge the heads back together.
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
目的与优点
- 目的是用来替代dropout这种过于简单的取舍
- 优点为实现方式简单
DeiT(改变loss)
如何应用
- 将一个表现较好的模型作为teacher:经过实验,发现使用CNN作为teacher表现更好
- 将ViT模型作为student
- 在输入的尾部增加一个distillation token作为计算distillation的依据
- 将计算出的distillation的结果作为loss
创新点
提出2种distill方式(两个计算distillation的公式):
- Soft distillation
$$L_{global} = (1-\lambda)L_{CE}(\mathrm{Softmax}(Z_s),y)+\lambda\tau^2\mathrm{KL}(\mathrm{Softmax}(Z_s/\tau),\mathrm{Softmax}(Z_t/\tau))$$
$Z_t$是teacher model的结果
$Z_s$是student model的结果
$y$是真实值
$L_{CE}$是cross entropy,$\mathrm{KL}$是Kullback-Leibler divergence
$\tau$是蒸馏温度,$\lambda$是平衡两项loss的权重(代码中的alpha)
- Hard-label distillation
$$L_{global}^{\mathrm{hardDistill}} = \frac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),y)+\frac{1}{2}L_{CE}(\mathrm{Softmax}(Z_s),\mathrm{argmax}(Z_t))$$
代码体现
- 将Distillation token加入输入
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return val is not None
# classes
class DistillMixin:
    """Mixin giving a ViT-style model an optional distillation token.

    :meth:`forward` reads ``to_patch_embedding``, ``cls_token``,
    ``pos_embedding``, ``_attend``, ``pool``, ``to_latent`` and ``mlp_head``
    from the host class.
    """

    def forward(self, img, distill_token = None):
        """Run the model, optionally carrying a distillation token.

        Args:
            img: input passed to ``self.to_patch_embedding``.
            distill_token: optional token with a leading singleton batch axis
                (it is repeated over the batch); ``None`` disables distillation.

        Returns:
            ``out`` when ``distill_token`` is ``None``; otherwise the pair
            ``(out, distill_tokens)``, where ``distill_tokens`` is the last
            token position after ``self._attend``.
        """
        distilling = exists(distill_token)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]

        if distilling:
            # The distillation token goes at the very end of the sequence,
            # after the positional embedding has been added.
            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
            x = torch.cat((x, distill_tokens), dim = 1)

        x = self._attend(x)

        if distilling:
            # Split the trailing distillation token off again.
            x, distill_tokens = x[:, :-1], x[:, -1]

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        out = self.mlp_head(x)

        if distilling:
            return out, distill_tokens

        return out
- 计算loss
class DistillWrapper(nn.Module):
    """Compute the DeiT-style distillation loss for a student/teacher pair.

    The wrapper owns the learnable distillation token and a small head
    (``LayerNorm`` + ``Linear``) that turns the student's distillation-token
    output into class logits.

    Args:
        teacher: callable returning class logits for an image batch.
        student: one of the ``Distillable*`` vision transformers; must expose
            ``dim`` and ``num_classes``.
        temperature: default softmax temperature for soft distillation.
        alpha: default weight of the distillation term in the final loss.
        hard: use hard-label distillation (teacher argmax) instead of KL.
    """

    def __init__(
        self,
        *,
        teacher,
        student,
        temperature = 1.,
        alpha = 0.5,
        hard = False
    ):
        super().__init__()
        assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'

        self.teacher = teacher
        self.student = student

        dim = student.dim
        num_classes = student.num_classes
        self.temperature = temperature
        self.alpha = alpha
        self.hard = hard

        self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

        self.distill_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
        """Return ``(1 - alpha) * CE(student, labels) + alpha * distill_loss``.

        Args:
            img: image batch given to both teacher and student.
            labels: ground-truth class indices.
            temperature: per-call override of ``self.temperature``.
            alpha: per-call override of ``self.alpha``.
            **kwargs: forwarded to the student.
        """
        alpha = self.alpha if alpha is None else alpha
        T = self.temperature if temperature is None else temperature

        # The teacher only provides targets; no gradients flow through it.
        with torch.no_grad():
            teacher_logits = self.teacher(img)

        student_logits, distill_tokens = self.student(img, distill_token = self.distillation_token, **kwargs)
        distill_logits = self.distill_mlp(distill_tokens)

        loss = F.cross_entropy(student_logits, labels)

        if not self.hard:
            distill_loss = F.kl_div(
                F.log_softmax(distill_logits / T, dim = -1),
                F.softmax(teacher_logits / T, dim = -1).detach(),
                reduction = 'batchmean')
            # The T^2 factor is the usual rescaling for temperature-softened KL.
            distill_loss *= T ** 2
        else:
            teacher_labels = teacher_logits.argmax(dim = -1)
            # Hard labels must supervise the distillation head, exactly like
            # the soft branch; scoring student_logits here would leave the
            # distillation token and its head without any gradient.
            distill_loss = F.cross_entropy(distill_logits, teacher_labels)

        return loss * (1 - alpha) + distill_loss * alpha
改变网络结构
CaiT
1. 改变残差网络(LayerScale+Class-Attention)
结构比较
- 原来的残差网络
原来的残差网络公式:
$$x_l'=x_l+\mathrm{SA}(x_l)$$
$$x_{l+1}=x_l'+\mathrm{FFN}(x_l')$$
- 新的残差网络
新的残差网络公式:
$$x_l'=x_l+\mathrm{diag}(\lambda_{l,1},\dots,\lambda_{l,d})\,\mathrm{SA}(x_l)$$
$$x_{l+1}=x_l'+\mathrm{diag}(\lambda'_{l,1},\dots,\lambda'_{l,d})\,\mathrm{FFN}(x_l')$$
这里的权重$\lambda$是一个初始值很小的可学习的参数
代码体现
- 增加了函数
class LayerScale(nn.Module):
    """Wrap ``fn`` and multiply its output by a learnable per-channel scale.

    Args:
        dim: size of the last axis of ``fn``'s output; one scale per channel.
        fn: the wrapped callable or module.
        depth: depth used to pick the initial scale value.
    """

    def __init__(self, dim, fn, depth):
        super().__init__()
        # Initial epsilon by depth (per the original note: section 2 of the paper).
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        # Shape (1, 1, dim) so it broadcasts over the two leading axes.
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) * scale``."""
        return self.fn(x, **kwargs) * self.scale
- 修改了残差网络
class Transformer(nn.Module):
    """CaiT transformer stack whose sublayers are wrapped in ``LayerScale``.

    Only the constructor is shown in this excerpt. Each of the ``depth``
    layers holds an attention sublayer and a feed-forward sublayer, each
    wrapped as ``LayerScale(PreNorm(...))``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        # Before LayerScale was introduced, the two entries were just
        #   PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
        #   PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
        for ind in range(depth):
            # LayerScale is given the 1-based index of this layer as its depth.
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))
意义
通过加入一个特别小的、可学习的参数$\lambda$,可以有效降低每层计算后对原信息的变化,从而减缓网络层数变大时发生的过拟合现象
2. Class-Attention结构
结构比较
最初的ViT模型中,class token从一开始就跟着特征信息进入网络进行训练的
ViT中计算Attention Value时Q K V的获取公式:
$$Q=W_qz+b_q$$
$$K=W_kz+b_k$$
$$V=W_vz+b_v$$
$$z=[x_{class},x_{patches}]$$
而在CaiT中,将特征学习和类别学习分开,先只对特征进行学习,要结束时(demo中是12个self层,2个class层),才将class token代入,此时特征只进行注意力的计算,不通过前馈神经网络
self-attention中Q K V的获取公式:
$$Q=W_qx_{patches}+b_q$$
$$K=W_kx_{patches}+b_k$$
$$V=W_vx_{patches}+b_v$$
class-attention中Q K V的获取公式:
$$Q=W_qx_{class}+b_q$$
$$K=W_kz+b_k$$
$$V=W_vz+b_v$$
$$z=[x_{class},x_{patches}]$$
代码体现
class Transformer(nn.Module):
    """CaiT transformer stack used for both self-attention and class-attention.

    Each layer is an attention sublayer followed by a feed-forward sublayer,
    both wrapped as ``LayerScale(PreNorm(...))`` and applied residually.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            # LayerScale is given the 1-based index of this layer as its depth.
            self.layers.append(nn.ModuleList([
                LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
                LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
            ]))

    def forward(self, x, context = None):
        """Run every selected layer over ``x``.

        When ``context`` is None, ``x`` holds the patch tokens (self-attention).
        Otherwise ``context`` holds the patch tokens and ``x`` the class token
        (class-attention).
        """
        # dropout_layers is defined elsewhere; presumably it picks the subset
        # of layers to run according to layer_dropout -- TODO confirm.
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
class CaiT(nn.Module):
    """CaiT classifier: patch-only self-attention, then class-attention.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of a patch; must divide ``image_size``.
        num_classes: number of output classes.
        dim: token embedding width.
        depth: number of layers in the patch (self-attention) transformer.
        cls_depth: number of layers in the class-attention transformer.
        heads, mlp_dim, dim_head, dropout, layer_dropout: passed to both
            transformers.
        emb_dropout: dropout applied to the embedded patches.
    """

    def __init__(
        self, *, image_size, patch_size, num_classes, dim, depth, cls_depth, heads, mlp_dim, dim_head = 64, dropout = 0., emb_dropout = 0., layer_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2  # three input channels are hard-coded

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        # No slot for the class token: it never passes through the patch stage.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.dropout = nn.Dropout(emb_dropout)

        # Two separate transformers: one for patches, one for the class token.
        self.patch_transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for an image batch."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        # Patch stage first; the class token has not been introduced yet.
        x = self.patch_transformer(x)

        # Class stage: the class token is the query input and the patch
        # tokens are the context, so the output is the updated class token.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = self.cls_transformer(cls_tokens, context = x)

        return self.mlp_head(x[:, 0])
Token-to-Token ViT(Patch_embedding)
网络结构
1. 整体网络结构
分为2大块,可以简单归类为不带class_token的patch_embedding部分和带class_token的Transformer部分。带class_token部分的网络结构与ViT相同,变化较大的是前面新加的不带class_token部分的结构,作者将其命名为“Tokens-to-Token module”
2. Tokens-to-Token module
整个过程可以分为如下几步:
- 刚开始的时候会将$C$通道的图片分为$(h\cdot w)$个patch,每个patch的大小为$p_1\times p_2$,然后把每个patch里的数据拉直,从而得到$(h\cdot w)$个大小为$(p_1\cdot p_2\cdot C)$的tokens
- 将这$(h\cdot w)$个大小为$n$(即$n=p_1\cdot p_2\cdot C$)的tokens进行重组,变成$n$个大小为$h\times w$的图片
- 对这$n$个大小为$h\times w$的图片进行Unfold处理,假设kernel的大小为$[k_1,k_2]$,则经过Unfold之后,我们能得到$l$个大小为$n\cdot k_1 \cdot k_2$的tokens
$$l=\left\lfloor\frac{h+2p-k_1}{k_1-s}+1\right\rfloor\times \left\lfloor\frac{w+2p-k_2}{k_2-s}+1\right\rfloor$$
$p$是padding的大小
$s$是相邻滑块之间重叠(overlap)的大小,即滑动的步长(stride)为$k-s$
- 我们希望保留不同区域之间的关系,所以我们对这些tokens进行reshape,得到$n\cdot k_1 \cdot k_2$个大小为$l$的tokens
- 将每个token代入到Transformer中计算,这样我们能够将各领域间的信息共享
- 重复step 2
Unfold操作的本质就是另一种意义上的reshape,不会进行学习和计算,就是把滑块包含区域内的Tokens拼接成一个Token
代码实现
def conv_output_size(image_size, kernel_size, stride, padding = 0):
    """Return the output side length of a sliding window along one axis.

    Computes ``int((image_size - kernel_size + 2 * padding) / stride + 1)``.
    ``padding`` defaults to 0 so that three-argument callers (such as the PiT
    code in this file) are valid.
    """
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
# classes
class RearrangeImage(nn.Module):
    """Fold a token sequence ``(b, h*w, c)`` back into an image ``(b, c, h, w)``.

    The height is taken as ``int(sqrt(number_of_tokens))``, so the token count
    is expected to be a perfect square.
    """

    def forward(self, x):
        side = int(math.sqrt(x.shape[1]))
        return rearrange(x, 'b (h w) c -> b c h w', h = side)
# main class
class T2TViT(nn.Module):
    """Tokens-to-Token ViT: a ViT whose patch embedding is a T2T module.

    Args:
        image_size: side length of the (square) input image.
        num_classes: number of output classes.
        dim: token width fed to the main transformer.
        depth, heads, mlp_dim: main transformer settings; required unless
            ``transformer`` is supplied.
        pool: ``'cls'`` (use the class token) or ``'mean'`` (average tokens).
        channels: number of input image channels.
        dim_head, dropout: forwarded to the transformers built here.
        emb_dropout: dropout applied after adding the positional embedding.
        transformer: optional pre-built main transformer.
        t2t_layers: ``(kernel_size, stride)`` pair for each T2T stage.
    """

    def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None, mlp_dim = None, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        layers = []
        layer_dim = channels
        output_image_size = image_size

        for i, (kernel_size, stride) in enumerate(t2t_layers):
            # Unfold concatenates kernel_size**2 positions into one token.
            layer_dim *= kernel_size ** 2
            is_first = i == 0
            output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

            # Each stage has four steps: fold tokens back into an image
            # (skipped for the raw input image), Unfold, move the token axis
            # forward, then one single-head transformer layer.
            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
            ])

        layers.append(nn.Linear(layer_dim, dim))
        # The patch embedding is now this whole T2T pipeline rather than the
        # single Rearrange used by the plain ViT.
        self.to_patch_embedding = nn.Sequential(*layers)

        # +1 reserves a position for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        if not exists(transformer):
            assert all([exists(depth), exists(heads), exists(mlp_dim)]), 'depth, heads, and mlp_dim must be supplied'
            self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        else:
            self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for an image batch."""
        # Only this patch embedding differs from the plain ViT forward pass.
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
Cross-ViT(Multi-Scale Features)
整体结构
- 先将原图片按照大尺寸和小尺寸进行分割,这样就有了S-Batch(Small patch size Batch)和L-Batch(Large patch size Batch)
- 先各自进行Transformer计算
- 将得到的结果使用Cross-Attention方法进行计算
- 将不同patch size的检测结果相加就是该网络的给出的结果(代码里就是两个Batch的CLS token分别经过mlp_head函数之后,将两个得到的值相加)
def forward(self, img):
    """Return the sum of the small-patch and large-patch branch logits.

    The image is embedded once per branch, both token sets go through the
    multi-scale encoder, and each branch's token at index 0 (its CLS token)
    is classified by that branch's own head.
    """
    sm_tokens = self.sm_image_embedder(img)
    lg_tokens = self.lg_image_embedder(img)

    sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

    # Position 0 along the token axis is the CLS token of each branch.
    sm_cls = sm_tokens[:, 0]
    lg_cls = lg_tokens[:, 0]

    sm_logits = self.sm_mlp_head(sm_cls)
    lg_logits = self.lg_mlp_head(lg_cls)

    return sm_logits + lg_logits
Cross-Attention方法
- 取L-Batch的CLS token:$x^l_{cls}$,S-Batch的patch_token:$x^s_{patch}$,经过计算可以得到最下面虚线框内的数据$x^{'l}$
$$x^{'l}=[f^l(x^l_{cls})\,\|\,x^s_{patch}]$$
- 用$x^{'l}_{cls}$计算Q,用$x^{'l}$计算K和V
$$Q=x^{'l}_{cls}W_q\qquad K=x^{'l}W_k\qquad V=x^{'l}W_v$$
- 计算Cross-Attention值$\mathrm{CA}(x^{'l})$
$$\mathrm{CA}(x^{'l})=\mathrm{Softmax}\left(QK^T/\sqrt{C/h}\right)V$$
- 计算新的CLS token:$y^l_{cls}$
$$y^l_{cls}=g^l(f^l(x^l_{cls})+\mathrm{CA}(x^{'l}))$$
- 拼接得到新的L-Batch
$$z^l=[y^l_{cls}\,\|\,x^l_{patch}]$$
$f^l(\cdot)$和$g^l(\cdot)$都是针对L-Batch的线性变换
代码实现(Cross-Attention部分)
class Attention(nn.Module):
    """Multi-head attention usable as self-attention or cross-attention.

    Queries always come from ``x``; keys and values come from ``context``
    (which falls back to ``x``).

    Args:
        dim: size of the last axis of the input tokens.
        heads: number of attention heads.
        dim_head: per-head width of q, k and v.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        """Attend from ``x`` to ``context``.

        Args:
            x: query tokens, layout ``(b, n, dim)``.
            context: key/value tokens; defaults to ``x``.
            kv_include_self: when True, ``x`` is prepended to ``context`` so
                the queries can also attend to themselves.
        """
        h = self.heads
        context = default(context, x)

        if kv_include_self:
            # Cross attention requires the CLS token to include itself as key / value.
            context = torch.cat((x, context), dim = 1)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# projecting CLS tokens, in the case that small and large patch tokens have different dimensions
class ProjectInOut(nn.Module):
    """Run ``fn`` in a ``dim_out``-wide space and project back to ``dim_in``.

    When ``dim_in == dim_out`` both projections are ``nn.Identity``; otherwise
    they are ``nn.Linear`` layers.

    Args:
        dim_in: width of the incoming (and returned) tokens.
        dim_out: width that ``fn`` operates on.
        fn: the wrapped callable or module.
    """

    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        if dim_in != dim_out:
            self.project_in = nn.Linear(dim_in, dim_out)
            self.project_out = nn.Linear(dim_out, dim_in)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

    def forward(self, x, *args, **kwargs):
        """Project in, apply ``fn`` with any extra arguments, project out."""
        projected = self.project_in(x)
        result = self.fn(projected, *args, **kwargs)
        return self.project_out(result)
# cross attention transformer
class CrossTransformer(nn.Module):
    """Exchange information between two branches through their CLS tokens.

    Each layer holds two cross-attention modules: one lets the small-branch
    CLS token attend to the large-branch patch tokens, the other the reverse.
    ``ProjectInOut`` handles the case where the branches differ in width.
    """

    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
            ]))

    def forward(self, sm_tokens, lg_tokens):
        """Update both CLS tokens and return the re-assembled token sets.

        The first token of each input is treated as its CLS token; patch
        tokens are returned unchanged.
        """
        sm_cls, sm_patch_tokens = sm_tokens[:, :1], sm_tokens[:, 1:]
        lg_cls, lg_patch_tokens = lg_tokens[:, :1], lg_tokens[:, 1:]

        for sm_attend_lg, lg_attend_sm in self.layers:
            # Residual update of each CLS token from the other branch's patches.
            sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
            lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls

        sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim = 1)
        lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim = 1)
        return sm_tokens, lg_tokens
PiT(加Pooling层)
整体结构
- 先进行embedding,得到基础的tokens
- 经过若干次Transformer层进行计算
- 进入pooling层,patch_tokens进行下采样,cls_tokens通过全连接层使维度与patch_tokens的契合
- 重复步骤2,重复步骤3,重复步骤2
- 使用cls_tokens输出结果
Pooling Layer
- 由于pooling是针对二维Tensor的操作,所以先将patch_token改变成二维的
- 使用DepthWise操作进行下采样,减少维度的同时提高深度(这里我觉得可能是论文作者写错了)
DepthWise操作(Xception的结构)
假设需要进行操作的数据维度为$w\times h\times d$
- 先用$d$个$3\times 3$的卷积核,每个卷积核对应一个通道,步长为2,得到维度为$\frac{w}{2}\times \frac{h}{2}\times d$的数据
- 使用$2d$个$1\times 1\times d$的卷积核,分别将得到的各通道数据加权求和,得到维度为$\frac{w}{2}\times \frac{h}{2}\times 2d$的数据
DepthWise操作(代码里的结构)
假设需要进行操作的数据维度为$w\times h\times d$
- 先用$2d$个$1 \times 3\times 3$的卷积核,每2个卷积核对应一个通道,步长为2,得到维度为$\frac{w}{2}\times \frac{h}{2}\times 2d$的数据
- 使用$2d$个$1\times 1\times 2d$的卷积核,分别将得到的各通道数据加权求和,得到维度为$\frac{w}{2}\times \frac{h}{2}\times 2d$的数据
- 将patch_token恢复成一维的以进行后面Transformer层的计算
- cls_token直接使用一个全连接层提高深度
代码实现
# depthwise convolution, for pooling
class DepthWiseConv2d(nn.Module):
    """Grouped convolution followed by a 1x1 convolution, used for pooling.

    The first convolution uses ``groups = dim_in`` and already widens the
    channels to ``dim_out``; the 1x1 convolution then mixes the ``dim_out``
    channels. ``nn.Conv2d`` requires ``dim_out`` to be a multiple of
    ``dim_in`` for this grouping.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        # For comparison, the Xception-style layout keeps dim_in channels in
        # the grouped convolution and widens only in the 1x1 step:
        #   nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
        #   nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_out, dim_out, kernel_size = 1, bias = bias)
        )

    def forward(self, x):
        """Apply both convolutions to ``x``."""
        return self.net(x)
# pooling layer
class Pool(nn.Module):
    """PiT pooling layer: downsample patch tokens and widen the class token.

    Patch tokens are folded into a 2D map, passed through a stride-2
    depthwise convolution that doubles the channels, and flattened again. The
    class token is widened to ``dim * 2`` by a linear layer.
    """

    def __init__(self, dim):
        super().__init__()
        self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1)
        self.cls_ff = nn.Linear(dim, dim * 2)

    def forward(self, x):
        """Pool ``x`` whose first token is the class token."""
        cls_token, tokens = x[:, :1], x[:, 1:]

        cls_token = self.cls_ff(cls_token)

        # Fold the patch tokens into a square map; the side is
        # int(sqrt(number_of_patch_tokens)).
        side = int(sqrt(tokens.shape[1]))
        tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = side)
        tokens = self.downsample(tokens)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')

        return torch.cat((cls_token, tokens), dim = 1)
# main class
class PiT(nn.Module):
    """Pooling-based Vision Transformer.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of a patch; must divide ``image_size``.
        num_classes: number of output classes.
        dim: initial token width; doubled after every pooling layer.
        depth: tuple with the number of transformer blocks before each
            downsizing.
        heads: head count, expanded by ``cast_tuple`` to one entry per stage.
        mlp_dim, dim_head, dropout: forwarded to every transformer stage.
        emb_dropout: dropout applied after adding the positional embedding.
    """

    def __init__(
        self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, dim_head = 64, dropout = 0., emb_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing'
        heads = cast_tuple(heads, len(depth))

        patch_dim = 3 * patch_size ** 2  # three input channels are hard-coded

        # Overlapping patches: the stride is half the patch size.
        self.to_patch_embedding = nn.Sequential(
            nn.Unfold(kernel_size = patch_size, stride = patch_size // 2),
            Rearrange('b c n -> b n c'),
            nn.Linear(patch_dim, dim)
        )

        output_size = conv_output_size(image_size, patch_size, patch_size // 2)
        num_patches = output_size ** 2

        # +1 reserves a position for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        layers = []

        for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)):
            not_last = ind < (len(depth) - 1)

            layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout))

            if not_last:
                # Pool between stages; it doubles the token width.
                layers.append(Pool(dim))
                dim *= 2

        self.layers = nn.Sequential(*layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits computed from the class token."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.layers(x)

        return self.mlp_head(x[:, 0])
LeViT
整体结构
我们可以清晰地看到,它分为了3个大块:
- 卷积部分:这部分的作用在于提取原图中的特征并且降低单张图片的维度(文章验证了,在进行Transformer之前加入小的卷积网络可以提高准确率)
- Transformer部分:这部分被分为了3个stage。可以看出,当不进行down sampling时,有残差网络;进行down sampling时,由于形状发生了变化,故没有残差结构。我们之后会着重分析这一部分
- Classifier部分:在LeViT中,我们没有使用cls_token结构,而是简单的将最后的结果进行均值池化。之后如果有蒸馏操作,那么将池化后的结果分别代入到学生分类器和老师分类器中进行分类(在代码中是用两个torch.Linear()实现的);如果没有蒸馏操作则直接进行分类。
- TIPS: 每次卷积后面都有初始权重为0的BatchNorm层
Attention Block
左边的是正常部分,右边的是Shrink Attention
$D$是Q、K的维度,在代码中被设置成了32
不进行down sampling时,V的维度为$2D$,进行down sampling时,V的维度为$4D$
$N$是head的数量,不进行down sampling时$N=C/2D$;如果进行down sampling,$N=C/D$
不进行down sampling(正常的)
假设输入尺寸为 C × H × W C\times H\times W C×H×W
- 计算$Q$、$K$、$V$
对原图进行point convolution,得到Q、K、V
$$Q\mathrm{.shape}=ND\times H\times W$$
$$K\mathrm{.shape}=ND\times H\times W$$
$$V\mathrm{.shape}=2ND\times H\times W$$
然后将其整形
$$Q\mathrm{.shape}=N\times HW\times D$$
$$K\mathrm{.shape}=N\times HW\times D$$
$$V\mathrm{.shape}=N\times HW\times 2D$$
- 计算attention bias($B$)
用attention bias代替ViT中的pos_embedding,使得每一层都会代入位置信息(文中引用了他人的观点,论证加入位置信息后能提高Transformer的准确率)
pos_indices的尺寸为$HW\times HW$,将其扩成$N\times HW\times HW$
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
# fmap_size = 14
q_range = torch.arange(0, fmap_size, step = 1)
k_range = torch.arange(fmap_size)
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
# q_pos: torch.Size([14, 14, 2])
# k_pos: torch.Size([14, 14, 2])
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
# rearranged q_pos: torch.Size([196, 2])
# rearranged k_pos: torch.Size([196, 2])
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
# rel_pos: torch.Size([196, 196, 2])
x_rel, y_rel = rel_pos.unbind(dim = -1)
# x_rel: torch.Size([196, 196])
# y_rel: torch.Size([196, 196])
pos_indices = (x_rel * fmap_size) + y_rel
# pos_indices: torch.Size([196, 196])
- 计算Attention weight($A$)
$$A=\mathrm{Softmax}\left((QK^T+B)/\sqrt{D}\right)$$
$$A\mathrm{.shape}=N\times HW\times HW$$
- 计算Attention Value
$$\mathrm{Attention}(Q,K,V)=AV$$
$$\mathrm{Attention}(Q,K,V)\mathrm{.shape}=N\times HW\times 2D$$
- 对输出进行处理
先进行整形,将其整形为$2ND\times H\times W$
再使用Hardswish激活函数进行计算
最后使用point convolution将维度由$2ND\times H\times W$改为$C\times H\times W$
$$\mathrm{Hardswish}(x)= \begin{cases} 0& x\leq -3 \\ x& x\geq +3\\ x\cdot(x+3)/6& \text{otherwise} \end{cases}$$
进行down sampling(Shrink Attention)
假设输入尺寸为 C × H × W C\times H\times W C×H×W
- 计算$Q$、$K$、$V$
对原图进行point convolution,得到Q、K、V;求Q时,步长设为2
$$Q\mathrm{.shape}=ND\times \frac{H}{2}\times \frac{W}{2}$$
$$K\mathrm{.shape}=ND\times H\times W$$
$$V\mathrm{.shape}=4ND\times H\times W$$
然后将其整形
$$Q\mathrm{.shape}=N\times \frac{HW}{4}\times D$$
$$K\mathrm{.shape}=N\times HW\times D$$
$$V\mathrm{.shape}=N\times HW\times 4D$$
- 计算attention bias($B$)
用attention bias代替ViT中的pos_embedding,使得每一层都会代入位置信息(文中引用了他人的观点,论证加入位置信息后能提高Transformer的准确率)
pos_indices的尺寸为$\frac{HW}{4}\times HW$,将其扩成$N\times \frac{HW}{4}\times HW$
self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)
# fmap_size = 14
q_range = torch.arange(0, fmap_size, step = 2)
k_range = torch.arange(fmap_size)
q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)
# q_pos: torch.Size([7, 7, 2])
# k_pos: torch.Size([14, 14, 2])
q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
# rearranged q_pos: torch.Size([49, 2])
# rearranged k_pos: torch.Size([196, 2])
rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()
# rel_pos: torch.Size([49, 196, 2])
x_rel, y_rel = rel_pos.unbind(dim = -1)
# x_rel: torch.Size([49, 196])
# y_rel: torch.Size([49, 196])
pos_indices = (x_rel * fmap_size) + y_rel
# pos_indices: torch.Size([49, 196])
- 计算Attention weight($A$)
$$A=\mathrm{Softmax}\left((QK^T+B)/\sqrt{D}\right)$$
$$A\mathrm{.shape}=N\times \frac{HW}{4}\times HW$$
- 计算Attention Value
$$\mathrm{Attention}(Q,K,V)=AV$$
$$\mathrm{Attention}(Q,K,V)\mathrm{.shape}=N\times \frac{HW}{4}\times 4D$$
- 对输出进行处理
先进行整形,将其整形为$4ND\times \frac{H}{2}\times \frac{W}{2}$
再使用Hardswish激活函数进行计算
最后使用point convolution将维度由$4ND\times \frac{H}{2}\times \frac{W}{2}$改为$C'\times \frac{H}{2}\times \frac{W}{2}$
$$\mathrm{Hardswish}(x)= \begin{cases} 0& x\leq -3 \\ x& x\geq +3\\ x\cdot(x+3)/6& \text{otherwise} \end{cases}$$
MLP
MLP是比较简单的2层point convolution的组合
class FeedForward(nn.Module):
    """Two pointwise (1x1) convolutions with GELU and dropout in between.

    Args:
        dim: number of input and output channels.
        mult: the hidden layer has ``dim * mult`` channels.
        dropout: dropout probability used after the activation and at the end.
    """

    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to a feature map; the channel count is preserved."""
        return self.net(x)
代码实现
class Attention(nn.Module):
    """LeViT attention on 2D feature maps, with a learned positional bias.

    q, k and v come from 1x1 convolutions followed by BatchNorm. With
    ``downsample = True`` the query convolution uses stride 2, so the output
    map is spatially smaller than the input.

    Args:
        dim: number of input channels.
        fmap_size: side length of the (square) input feature map.
        heads: number of attention heads.
        dim_key: per-head width of q and k.
        dim_value: per-head width of v.
        dropout: dropout probability at the end of the output block.
        dim_out: number of output channels; defaults to ``dim``.
        downsample: use stride 2 for the queries.
    """

    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key * heads
        inner_dim_value = dim_value * heads
        dim_out = default(dim_out, dim)

        self.heads = heads
        self.scale = dim_key ** -0.5

        q_stride = 2 if downsample else 1
        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = q_stride, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)

        # The final BatchNorm starts with zero weight, so the block initially
        # outputs only the BatchNorm bias.
        out_batch_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(out_batch_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            out_batch_norm,
            nn.Dropout(dropout)
        )

        # Positional bias: one learned value per head for every pair of
        # absolute (row, column) offsets between a query and a key position.
        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(0, fmap_size, step = q_stride)
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range), dim = -1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range), dim = -1)

        q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        # Flatten the two offsets into one index into pos_bias.
        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)

    def apply_pos_bias(self, fmap):
        """Add the per-head positional bias to the attention logits ``fmap``."""
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        # NOTE(review): the bias is divided by self.scale, presumably because
        # the logits were already multiplied by it -- confirm against the paper.
        return fmap + (bias / self.scale)

    def forward(self, x):
        """Apply attention to a feature map and return a feature map."""
        h = self.heads

        q = self.to_q(x)
        # NOTE(review): axis 2 of the query map is used as the inner extent
        # when unflattening below; this relies on the map being square.
        y = q.shape[2]

        qkv = (q, self.to_k(x), self.to_v(x))
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
class Transformer(nn.Module):
    """LeViT stage: ``depth`` pairs of attention and feed-forward blocks.

    The attention residual is used only when the block neither downsamples
    nor changes the channel count; the feed-forward residual is always used.
    """

    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.layers = nn.ModuleList([])
        # A residual needs matching shapes, which downsampling or a change in
        # channel count would break.
        self.attn_residual = (not downsample) and dim == dim_out

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ]))

    def forward(self, x):
        """Run all layers over the feature map ``x``."""
        for attn, ff in self.layers:
            attn_res = x if self.attn_residual else 0
            x = attn(x) + attn_res
            x = ff(x) + x
        return x
class LeViT(nn.Module):
    """LeViT: convolutional stem, staged attention backbone, pooled classifier.

    Args:
        image_size: side length of the (square) input image.
        num_classes: number of output classes.
        dim, depth, heads: per-stage settings, expanded by ``cast_tuple`` to
            one entry per stage.
        mlp_mult: hidden-width multiplier of the feed-forward blocks.
        stages: number of stages.
        dim_key, dim_value: per-head widths of q/k and of v.
        dropout: dropout probability for the regular stages.
        num_distill_classes: when given, a second linear head is created and
            ``forward`` returns a pair.
    """

    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        # The check below requires each tuple to have exactly `stages` entries.
        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be a tuple that is less than the designated number of stages'

        # Four stride-2 convolutions shrink the image by a factor of 16.
        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        fmap_size = image_size // (2 ** 4)
        layers = []

        # Distinct loop names keep the constructor arguments un-shadowed.
        for ind, stage_dim, stage_depth, stage_heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(stage_dim, fmap_size, stage_depth, stage_heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                # A single downsampling block between stages: twice the heads,
                # the next stage's width, and half the map size.
                next_dim = dims[ind + 1]
                layers.append(Transformer(stage_dim, fmap_size, 1, stage_heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        # Both heads read the width of the last stage.
        final_dim = dims[-1]
        self.distill_head = nn.Linear(final_dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(final_dim, num_classes)

    def forward(self, img):
        """Return logits, or ``(logits, distill_logits)`` with a distill head."""
        x = self.conv_embedding(img)

        x = self.backbone(x)

        # No class token: the final map is average-pooled to one vector.
        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)

        if exists(distill):
            return out, distill

        return out
CvT(卷积+Transformer)
整体架构
注意: 在论文的图片里,我们能看到在stage3阶段加入了cls_token,并在最后使用cls_token进行分类。但是在源码中,作者并不是这么实现的,而是使用自适应平均池化(AdaptiveAvgPool2d)函数的输出作为分类器的输入。
整体步骤可以分为3个stage
一个stage内有2个模块,分别被称作Convolutional Token Embedding和Convolutional Transformer Blocks
详细模块
Convolutional Token Embedding
这个模块就是一个卷积层,目的是提取特征信息,降低图像大小
执行完卷积层之后会进行LayerNorm
Convolutional Transformer Blocks
这个模块与ViT的Transformer部分不同,具体体现在Q、K、V的生成方式
在ViT中,我们先将2D的信息重组为序列信息,然后使用线性映射得到Q、K、V
$$x_i^{q/k/v}=\mathrm{Linear}(x_i)$$
在CvT中,我们直接使用2D的信息,使用卷积操作得到Q、K、V的2D表达,然后再拉伸成序列信息的形式。
$$x_i^{q/k/v}=\mathrm{Flatten}(\mathrm{Conv2d}(\mathrm{Reshape2D}(x_i),s))$$
计算完成后,我们还会将序列信息重组为2D信息,以进行接下来的操作。
在阅读源码时,我们发现计算Attention Value时,我们并没有加上位置信息,也就是ViT中的pos_embedding,论文中也给了解释:
因为进行卷积映射时,我们在无形中也将各个元素与周围元素之间的联系保留了下来,因此pos_embedding也就没那么重要了。
代码实现
class DepthWiseConv2d(nn.Module):
    """Depthwise convolution, BatchNorm, then a 1x1 convolution.

    The grouped convolution keeps ``dim_in`` channels (``groups = dim_in``);
    only the final 1x1 convolution changes the channel count to ``dim_out``.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias)
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(dim_in),
            pointwise
        )

    def forward(self, x):
        """Apply the three layers to the feature map ``x``."""
        return self.net(x)
这个DepthWiseConv2d函数是Xception结构,可以与PiT的DepthWiseConv2d函数进行比较
class Attention(nn.Module):
    """CvT attention: q, k and v come from depthwise-convolution projections.

    Args:
        dim: number of input (and output) channels.
        proj_kernel: kernel size of the projections; padding is
            ``proj_kernel // 2``.
        kv_proj_stride: stride of the shared key/value projection.
        heads: number of attention heads.
        dim_head: per-head width.
        dropout: dropout probability after the output convolution.
    """

    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)

        # Convolutional projections; queries keep the full resolution while
        # keys and values share one projection with stride kv_proj_stride.
        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply attention to a 4D feature map and return a feature map."""
        # The input must be 4D; its last axis is needed to unflatten the output.
        *_, y = x.shape
        h = self.heads

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = 1)
        # Heads are folded into the batch axis and positions are flattened.
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
class CvT(nn.Module):
    """Convolutional vision Transformer made of three stages.

    Each stage is a strided convolution (token embedding), a ``LayerNorm``
    and a convolutional ``Transformer``. The stage settings are passed as
    ``s1_*``, ``s2_*`` and ``s3_*`` keyword arguments; ``dropout`` is shared
    by all stages.
    """

    def __init__(
        self,
        *,
        num_classes,
        s1_emb_dim = 64, s1_emb_kernel = 7, s1_emb_stride = 4, s1_proj_kernel = 3, s1_kv_proj_stride = 2, s1_heads = 1, s1_depth = 1, s1_mlp_mult = 4,
        s2_emb_dim = 192, s2_emb_kernel = 3, s2_emb_stride = 2, s2_proj_kernel = 3, s2_kv_proj_stride = 2, s2_heads = 3, s2_depth = 2, s2_mlp_mult = 4,
        s3_emb_dim = 384, s3_emb_kernel = 3, s3_emb_stride = 2, s3_proj_kernel = 3, s3_kv_proj_stride = 2, s3_heads = 6, s3_depth = 10, s3_mlp_mult = 4,
        dropout = 0.
    ):
        super().__init__()
        # Snapshot the constructor arguments so they can be grouped by stage
        # prefix below. This must stay ahead of any other local assignment.
        kwargs = dict(locals())

        dim = 3  # channels of the input image
        layers = []

        # Build the contents of each stage.
        for prefix in ('s1', 's2', 's3'):
            # group_by_key_prefix_and_remove_prefix is defined elsewhere;
            # presumably it splits off the entries starting with the prefix
            # and strips it from their keys -- TODO confirm.
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)

            layers.append(nn.Sequential(
                nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
                LayerNorm(config['emb_dim']),
                Transformer(dim = config['emb_dim'], proj_kernel = config['proj_kernel'], kv_proj_stride = config['kv_proj_stride'], depth = config['depth'], heads = config['heads'], mlp_mult = config['mlp_mult'], dropout = dropout)
            ))

            # The next stage consumes this stage's embedding width.
            dim = config['emb_dim']

        self.layers = nn.Sequential(
            *layers,
            # Classification uses adaptive average pooling, not a cls_token.
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Return class logits for an image batch."""
        return self.layers(x)