Masked Auto Encoder总结

Masked Auto Encoder总结

MAE简介

MAE是用于CV的自监督学习方法,优点是扩展性强的,方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。

MAE结构简图

MAE基于两个核心设计:

  1. 不对称的(asymmetric)编码解码结构,编码器仅仅对可见的patches进行编码,不对mask tokens进行任何处理,解码器将编码器的输出(latent representation)和 mask tokens 作为输入,重构image。
  2. 使用较高的mask比例(eg:75%)进行训练时,MAE展现了很强的迁移性能,且因为方法简单,可扩展性极强。

通过上图可以看到,我们首先将一张大图片切割为很多小的patches,然后随机mask了很多patch;再将没有被mask掉的图片送入encoder提取特征,之后将特征和被mask掉的patch(未经过任何处理)按照顺序拼接在一起,通过decoder进行图片重建。

MAE的encoder主要是基于ViT的,我们在讲解ViT的时候详细的说了ViT的结构(如果对ViT不了解,可以看我上一篇blog),所以MAE的encoder对于我们来讲就显得很容易理解。我们会对encoder做简要的说明,将主要篇幅放在MAE如何进行随机mask、如何还原原来的顺序、decoder的结构以及对于还原图片loss计算上。

Random Mask

首先我们进行MAE中第一步的解释,我们如何将图片随机mask。

进行mask之前还需要将图片变成patch,此时维度变化如下:
[ b , c , h , w ] − > [ b , ( h / p a t c h h ) ∗ ( w / p a t c h w ) , ( p a t c h h ∗ p a t c h w ∗ c ) ] [b,c,h,w]->[b,(h/patch_h)*(w/patch_w),(patch_h*patch_w*c)] [b,c,h,w]>[b,(h/patchh)(w/patchw),(patchhpatchwc)]
这部分代码中直接调用了Ross Wightman在github上的开源项目中的源码 源码链接,图片patch化的代码如下:

from timm.models.vision_transformer import PatchEmbed, Block
def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        # 这里是加上了位置编码,这部分在ViT文章中解释过,在这里不赘述
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

获得patch化的x之后,接下来就是随即掩码的过程。

random mask 逻辑

首先我们需要清楚,对于输入x,他的维度是 [ b , p a t c h _ n u m , d i m ] [b,patch\_num,dim] [b,patch_num,dim],我们掩码x,实际的掩码对象是dim维度的数据。由于一个batch中有 p a t c h _ n u m patch\_num patch_num个dim维数据,所以我们要掩码 p a t c h _ n u m ∗ 75 patch\_num * 75% patch_num75 这么多的数据。

为了达到这一目的,我们首先随机化一个 [ b , p a t c h _ n u m ] [b,patch\_num] [b,patch_num]维的向量,然后对其第一维(patch_num)进行排序,并且对处于前25%的部分对应的dim维数据保持原样,剩下的进行掩码。这样就完成了随机掩码75%的操作。

这里要注意,因为其用到的是argsort()函数,所以返回的是排序的下标。由于在模型中我们需要记住打乱的数据原来的位置,此信息在这段代码中由ids_shuffleids_restore两个变量来保存。

保存的逻辑如下:

打乱和还原的逻辑

我们可以发现,ids_shuffleids_restore是的下标和值存在明显的对应关系,可以通过这两个变量记录之前patch的实际位置。

之后从ids_shuffle中前提取出len_keep长度的信息保留,并且使用torch.gather()函数将被保留的信息挑出来。这里torch.gather()函数是一个巧妙的函数,具体的介绍可以通过这个链接查看 gather函数介绍。这里我们只需要知道,它可以按照我们保留的num_patches下标,将需要的dim维向量整合出来,并保存在x_masked变量中。

同样的,代码最后的mask变量(维度为 [ b , p a t c h _ n u m ] [b,patch\_num] [b,patch_num])也变得好理解了,我们用一个batch来举例,其中的变量是如下形式的: [ 1 , 0 , 1 , 1 , 1 , . . . 1 , 0 , 1 ] [1,0,1,1,1,...1,0,1] [1,0,1,1,1,...1,0,1]。其中1表示这个位置对应的dim维patch被mask了,反之没有。

random mask 实现

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
    

Encoder

前文说过,MAE中的Encoder是一个ViT,所以这里我们直接从代码中解读Encoder。

Encoder网络结构

从Encoder的网络结构可以看出,他基本就是一个ViT,我们可以看到非常熟悉的cls_token与pos_embed变量。

但是我们还注意到了,这个ViT的Transformer是通过Block实现的。这个Block同样也是来自Ross Wightman的代码,Block的源代码在后面贴出。我们可以看到Block是一个标准的多头注意力模型。

class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, 
                 img_size=224, 
                 patch_size=16, 
                 in_chans=3,
                 embed_dim=1024, 		 # encoder的隐藏层维度
                 depth=24,				# encoder中transformer的深度 
                 num_heads=16,
                 decoder_embed_dim=512,	
                 decoder_depth=8, 
                 decoder_num_heads=16,
                 mlp_ratio=4., 
                 norm_layer=nn.LayerNorm, 
                 norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

Block源代码

可以看出这是一个标准的多头注意力模型。

class Block(nn.Module):

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

Ecoder计算流程

可以看出Encoder的计算流程,相比ViT,只是多出了一个random_masking

最终x的维度为 [ b , k e e p _ l e n g t h + 1 , d i m ] [b,keep\_length + 1,dim] [b,keep_length+1,dim],之所以第二维+1是因为拼接上了 c l s _ t o k e n cls\_token cls_token

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

Decoder

Decoder网络结构

由于encoder输出和decoder输入的维度不同,所以decoder先通过一个线性层映射,将维度转换成decoder的输入维度。

这里的mask_token是个重要的部分,暂时他的维度是 [ 1 , 1 , d e c o d e r _ e m b e d _ d i m ] [1,1,decoder\_embed\_dim] [1,1,decoder_embed_dim],他表示的是一张图片中被掩码的patch。在之后的forward函数中,mask_token会被扩增为 [ b , p a t c h _ n u m ∗ 75 % , d e c o d e r _ e m b e d _ d i m ] [b,patch\_num * 75\%,decoder\_embed\_dim] [b,patch_num75%,decoder_embed_dim],表示所有batch中被掩码掉的patch。

之后就是我们熟悉的位置编码、transformer层、norm层以及维度转换的Linear层。

class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        # 这里是原来的encoder代码
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

Decoder计算流程

我们需要说明,decoder的输入是encoder的输出。所以我们首先需要线性层进行维度转换,下一步是很重要的一步——将被mask掉的值拼接回去,这一步的操作会结合代码详细讲解。

首先我们将 Decoder网络结构 中定义的mask_token进行repeat操作,使其最后的数量与被mask掉的块数量一致。具体讲解在代码中以注释的形式出现。

def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        '''
        这里mask_tokens最后的维度是[b,patch_num*75%,decoder_embed_dim]、
        第一维是batchsize是很好理解的,这就是x.shape[0]的含义
        
        第二维有点复杂,首先明确ids_restore.shape[1]是patch_num,x.shape[1]是keep_length + 1,这个1需要提醒,是在x上面加入的cls_token,所以 ids_restore.shape[1] + 1 - x.shape[1] 就是patch_num + 1 - (keep_length + 1) = patch_num - keep_length。而keep_length就是patch_num * 25%的结果(这一点如果忘掉的话可以看Random Mask部分回忆一下),所以第二维就是被mask部分的长度。
        
       	第三维很好理解,在repeat函数中参数为1,也就是保留之前的 decoder_embed_dim 维度
        '''
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        
        '''
        这里是将未被mask的x和刚刚构建出来的被mask部分拼接起来,但是在这里还没有恢复他们的正确位置,只是把mask_tokens简单的拼接在了x后面,并且此处的x是去掉了cls_token的
        '''
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        
        '''
        这里是利用gather函数将上一条代码的简单拼接改为了正确的patch位置。
        '''
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        
        '''
        这里将输入x的cls_token又重新拼接,因为我们并没有改变batch维度的顺序(原来在位置i的图片信息现在还在位置i),所以将参与了训练的cls_token直接拼接是不会有问题的
        '''
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
		
        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

Loss计算

图片重建部分的loss并没有调用pytorch提供的函数,但并不意味着这部分困难。相反,这里loss的计算逻辑很简单,只是计算被mask掉部分的预测值与实际值之间的欧氏距离。

Loss计算流程

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        # 将原始图片patch化
        target = self.patchify(imgs)
        # 将target归一化
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        
        # 计算每一个patch的loss 
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        
        # 为了计算实际预测的误差,应该只计算被mask的部分的loss,这时候mask矩阵就派上用场了,他将没有被mask的部分乘以0,除去了这部分loss的影响,只保留了mask部分的loss
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

atch的loss
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch

    # 为了计算实际预测的误差,应该只计算被mask的部分的loss,这时候mask矩阵就派上用场了,他将没有被mask的部分乘以0,除去了这部分loss的影响,只保留了mask部分的loss
    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss



  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值