Masked Auto Encoder总结
文章目录
MAE简介
MAE是用于CV的自监督学习方法,优点是扩展性强的,方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。
MAE基于两个核心设计:
- 不对称的(asymmetric)编码解码结构,编码器仅仅对可见的patches进行编码,不对mask tokens进行任何处理,解码器将编码器的输出(latent representation)和 mask tokens 作为输入,重构image。
- 使用较高的mask比例(eg:75%)进行训练时,MAE展现了很强的迁移性能,且因为方法简单,可扩展性极强。
通过上图可以看到,我们首先将一张大图片切割为很多小的patches,然后随机mask了很多patch;再将没有被mask掉的图片送入encoder提取特征,之后将特征和被mask掉的patch(未经过任何处理)按照顺序拼接在一起,通过decoder进行图片重建。
MAE的encoder主要是基于ViT的,我们在讲解ViT的时候详细的说了ViT的结构(如果对ViT不了解,可以看我上一篇blog),所以MAE的encoder对于我们来讲就显得很容易理解。我们会对encoder做简要的说明,将主要篇幅放在MAE如何进行随机mask、如何还原原来的顺序、decoder的结构以及对于还原图片loss计算上。
Random Mask
首先我们进行MAE中第一步的解释,我们如何将图片随机mask。
进行mask之前还需要将图片变成patch,此时维度变化如下:
[
b
,
c
,
h
,
w
]
−
>
[
b
,
(
h
/
p
a
t
c
h
h
)
∗
(
w
/
p
a
t
c
h
w
)
,
(
p
a
t
c
h
h
∗
p
a
t
c
h
w
∗
c
)
]
[b,c,h,w]->[b,(h/patch_h)*(w/patch_w),(patch_h*patch_w*c)]
[b,c,h,w]−>[b,(h/patchh)∗(w/patchw),(patchh∗patchw∗c)]
这部分代码中直接调用了Ross Wightman在github上的开源项目中的源码 源码链接,图片patch化的代码如下:
from timm.models.vision_transformer import PatchEmbed, Block
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
# 这里是加上了位置编码,这部分在ViT文章中解释过,在这里不赘述
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
获得patch化的x之后,接下来就是随即掩码的过程。
random mask 逻辑
首先我们需要清楚,对于输入x,他的维度是 [ b , p a t c h _ n u m , d i m ] [b,patch\_num,dim] [b,patch_num,dim],我们掩码x,实际的掩码对象是dim维度的数据。由于一个batch中有 p a t c h _ n u m patch\_num patch_num个dim维数据,所以我们要掩码 p a t c h _ n u m ∗ 75 patch\_num * 75% patch_num∗75 这么多的数据。
为了达到这一目的,我们首先随机化一个 [ b , p a t c h _ n u m ] [b,patch\_num] [b,patch_num]维的向量,然后对其第一维(patch_num)进行排序,并且对处于前25%的部分对应的dim维数据保持原样,剩下的进行掩码。这样就完成了随机掩码75%的操作。
这里要注意,因为其用到的是argsort()
函数,所以返回的是排序的下标。由于在模型中我们需要记住打乱的数据原来的位置,此信息在这段代码中由ids_shuffle
和ids_restore
两个变量来保存。
保存的逻辑如下:
我们可以发现,ids_shuffle
和ids_restore
是的下标和值存在明显的对应关系,可以通过这两个变量记录之前patch的实际位置。
之后从ids_shuffle
中前提取出len_keep
长度的信息保留,并且使用torch.gather()
函数将被保留的信息挑出来。这里torch.gather()
函数是一个巧妙的函数,具体的介绍可以通过这个链接查看 gather函数介绍。这里我们只需要知道,它可以按照我们保留的num_patches
下标,将需要的dim
维向量整合出来,并保存在x_masked
变量中。
同样的,代码最后的mask
变量(维度为
[
b
,
p
a
t
c
h
_
n
u
m
]
[b,patch\_num]
[b,patch_num])也变得好理解了,我们用一个batch
来举例,其中的变量是如下形式的:
[
1
,
0
,
1
,
1
,
1
,
.
.
.
1
,
0
,
1
]
[1,0,1,1,1,...1,0,1]
[1,0,1,1,1,...1,0,1]。其中1表示这个位置对应的dim
维patch被mask了,反之没有。
random mask 实现
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
Encoder
前文说过,MAE中的Encoder是一个ViT,所以这里我们直接从代码中解读Encoder。
Encoder网络结构
从Encoder的网络结构可以看出,他基本就是一个ViT,我们可以看到非常熟悉的cls_token与pos_embed变量。
但是我们还注意到了,这个ViT的Transformer是通过Block实现的。这个Block同样也是来自Ross Wightman的代码,Block的源代码在后面贴出。我们可以看到Block是一个标准的多头注意力模型。
class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=1024, # encoder的隐藏层维度
depth=24, # encoder中transformer的深度
num_heads=16,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.,
norm_layer=nn.LayerNorm,
norm_pix_loss=False):
super().__init__()
# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
Block源代码
可以看出这是一个标准的多头注意力模型。
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
Ecoder计算流程
可以看出Encoder的计算流程,相比ViT,只是多出了一个random_masking
。
最终x的维度为 [ b , k e e p _ l e n g t h + 1 , d i m ] [b,keep\_length + 1,dim] [b,keep_length+1,dim],之所以第二维+1是因为拼接上了 c l s _ t o k e n cls\_token cls_token。
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
Decoder
Decoder网络结构
由于encoder输出和decoder输入的维度不同,所以decoder先通过一个线性层映射,将维度转换成decoder的输入维度。
这里的mask_token是个重要的部分,暂时他的维度是
[
1
,
1
,
d
e
c
o
d
e
r
_
e
m
b
e
d
_
d
i
m
]
[1,1,decoder\_embed\_dim]
[1,1,decoder_embed_dim],他表示的是一张图片中被掩码的patch。在之后的forward
函数中,mask_token会被扩增为
[
b
,
p
a
t
c
h
_
n
u
m
∗
75
%
,
d
e
c
o
d
e
r
_
e
m
b
e
d
_
d
i
m
]
[b,patch\_num * 75\%,decoder\_embed\_dim]
[b,patch_num∗75%,decoder_embed_dim],表示所有batch中被掩码掉的patch。
之后就是我们熟悉的位置编码、transformer层、norm层以及维度转换的Linear层。
class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
super().__init__()
# --------------------------------------------------------------------------
# MAE encoder specifics
# 这里是原来的encoder代码
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------
Decoder计算流程
我们需要说明,decoder的输入是encoder的输出。所以我们首先需要线性层进行维度转换,下一步是很重要的一步——将被mask掉的值拼接回去,这一步的操作会结合代码详细讲解。
首先我们将 Decoder网络结构 中定义的mask_token进行repeat操作,使其最后的数量与被mask掉的块数量一致。具体讲解在代码中以注释的形式出现。
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
'''
这里mask_tokens最后的维度是[b,patch_num*75%,decoder_embed_dim]、
第一维是batchsize是很好理解的,这就是x.shape[0]的含义
第二维有点复杂,首先明确ids_restore.shape[1]是patch_num,x.shape[1]是keep_length + 1,这个1需要提醒,是在x上面加入的cls_token,所以 ids_restore.shape[1] + 1 - x.shape[1] 就是patch_num + 1 - (keep_length + 1) = patch_num - keep_length。而keep_length就是patch_num * 25%的结果(这一点如果忘掉的话可以看Random Mask部分回忆一下),所以第二维就是被mask部分的长度。
第三维很好理解,在repeat函数中参数为1,也就是保留之前的 decoder_embed_dim 维度
'''
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
'''
这里是将未被mask的x和刚刚构建出来的被mask部分拼接起来,但是在这里还没有恢复他们的正确位置,只是把mask_tokens简单的拼接在了x后面,并且此处的x是去掉了cls_token的
'''
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
'''
这里是利用gather函数将上一条代码的简单拼接改为了正确的patch位置。
'''
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
'''
这里将输入x的cls_token又重新拼接,因为我们并没有改变batch维度的顺序(原来在位置i的图片信息现在还在位置i),所以将参与了训练的cls_token直接拼接是不会有问题的
'''
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
Loss计算
图片重建部分的loss
并没有调用pytorch
提供的函数,但并不意味着这部分困难。相反,这里loss
的计算逻辑很简单,只是计算被mask掉部分的预测值与实际值之间的欧氏距离。
Loss计算流程
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
# 将原始图片patch化
target = self.patchify(imgs)
# 将target归一化
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
# 计算每一个patch的loss
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
# 为了计算实际预测的误差,应该只计算被mask的部分的loss,这时候mask矩阵就派上用场了,他将没有被mask的部分乘以0,除去了这部分loss的影响,只保留了mask部分的loss
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss
atch的loss
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
# 为了计算实际预测的误差,应该只计算被mask的部分的loss,这时候mask矩阵就派上用场了,他将没有被mask的部分乘以0,除去了这部分loss的影响,只保留了mask部分的loss
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss