基础论文学习(5)——MAE

MAE:Masked Autoencoders Are Scalable Vision Learners

Self-Supervised Learning

  • step1:先用无标签数据集,把参数从一张白纸训练到初步预训练模型,可以得到数据的 Visual Representation
  • step2:再从初步成型,根据你下游任务 Downstream Tasks的不同去用带标签的数据集把参数训练到完全成型。注意这是2个阶段。

在这里插入图片描述

第一个阶段不涉及任何下游任务,就是拿着一堆无标签的数据去预训练,没有特定的任务,这个话用官方语言表达叫做:in a task-agnostic way

第二个阶段涉及下游任务,就是拿着一堆带标签的数据去在下游任务上 Fine-tune,这个话用官方语言表达叫做:in a task-specific way

Self-Supervised Learning 不仅是在NLP领域,在CV, 语音领域也有很多经典的工作,如下图2所示。它可以分成3类:Data Centric, Prediction (也叫 Generative)Contrastive
在这里插入图片描述
其中的主流就是基于 Generative 的方法和基于 Contrative 的方法。如下图所示这里简单介绍下。

  • 基于 Generative 的方法主要关注的重建误差,比如对于 NLP 任务而言,一个句子中间盖住一个 token,让模型去预测,令得到的预测结果与真实的 token 之间的误差作为损失。如Diffusion、VAE等。
  • 基于 Contrastive 的方法不要求模型能够重建原始输入,而是希望模型能够在特征空间上对不同的输入进行分辨。如SimCLR等

在这里插入图片描述

1. Masked AutoEncoders (MAE) 原理架构

掩码自编码器 (masked autoencoders (MAE)) 要做的事情还是通过自监督学习将被masked抹去的图像块补充上。属于 Generative (Predictive) pre-training 的类型。这种类型自监督学习的另一个著名的例子就是 BERT。

对于 BERT 模型而言,一个 sentence 中间盖住一些 tokens,让模型去预测,令得到的预测结果与真实的 tokens 之间的误差作为损失。它告诉了我们直接 reconstruct sentence 也可以做到很 work。

对于 MAE 模型而言,一个 image 中间盖住一些 patches,让模型去预测,令得到的预测结果与真实的 image patches 之间的误差作为损失。它告诉了我们直接 reconstruct image 原图也可以做到很 work。

在这里插入图片描述

MAE架构:Mask 掉输入图像的随机的 patches 并重建它们。它基于两个核心理念:研究人员提出了一个非对称编码器 - 解码器架构,其中Encoder编码器只对可见的 patch 子集进行操作 (即没有被 mask 掉的 token),Decoder解码器可以从潜在表征和被 masked 掉的 token 重建原始图像。Decoder 的架构可以是十分轻量化的模型,且具体的架构对模型性能影响很大。研究人员进一步发现,Mask 掉大部分输入图像 (例如 75%)会产生重要且有意义 的自监督任务。

在这里插入图片描述
MAE 方法严格来讲属于一种去噪自编码器 (Denoising Auto-Encoders (DAE)),去噪自动编码器是一类自动编码器,它破坏输入信号,并学会重构原始的、未被破坏的信号。MAE 的 Encoder 和 Decoder 结构不同,是非对称式的。Encoder 将输入编码为 latent representation,而 Decoder 将从 latent representation 重建原始信号。

在这里插入图片描述

MAE 和 ViT 的做法一致,将图像划分成规则的,不重叠的 patches。然后按照均匀分布不重复地选择一些 patches 并且 mask 掉剩余的 patches。作者采用的 mask ratio 足够高,因此大大减小了 patches 的冗余信息,使得在这种情况下重建 images 不那么容易。(Hard Sample思想,增大loss加速收敛),不同mask ratio下的实验(Fine-tuning是即微调Encoder也微调分类linear head;Linear-probing是冻住Encoder只微调linear head) 证明75%是一个比较合适的取值。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

算法流程:

  • Patchify+Masking:首先将input image切分为patches,执行mask操作(75%),然后只把 可见的 patches(unmasked patches) 送入encoder中。再将encoder的输出(latent representations) 以及masked tokens(learnable embedding) 作为轻量级decoder的输入,decoder重构整张image。

  • Encoder: 编码器实际上就是ViT,将input image切分为不重叠的正方形patches之后,执行linear projection,再加上positional embeddings (the sine-cosine version) ,然后送入transformer blocks。

  • Decoder: 同样使用ViT,将masked tokens + unmasked tokens作为输入,加上位置编码 (the sine-cosine version) 。decoder的最后一层是linear projection,输出通道数量和一个patch内的pixel数量相同(方便重构),然后再reshape重构image。损失函数使用MSE,损失函数只对masked patches计算(和BERT相同)。同时作者也尝试了normalization的方式,即计算一个patch内像素值的均值和标准差,然后对patch执行normalization,此时encoder的重构任务发生了一些变化,需要重构normalized pixel values,实验表明这种方式效果更好一点。

  • MAE中decoder的设计并不重要,因为预训练结束之后,只保留encoder(作为backbone训练下游任务),decoder只需要完成预训练时的图像重构任务。但是作者也表示decoder决定了latent representations的语义级别。

# Patchify

# PatchEmbed

# Position Encoding

# Random_Shuffle and Mask 25% Token

# Add CLS token

# Encoder(ViT)

# MaskToken concat UnMaskedToken(learnable)

# UnShuffle

# Decoder(ViT)

# Mask MSE Loss

为什么 BERT (2018) 提出这么久以后,直到 BEIT (2021.6) 和 MAE (2021.11) 之前,一直在 CV 领域都没有一个很类似的 CV BERT 出现?

  1. CV 和 NLP 主流架构不同:直到 ViT (2020.12) 出现之前,CV 的主流架构一直是以卷积网络为主,NLP 的主流架构一直是以 Transformer 为主。卷积核作用在一个个的 grid 上面,直观来讲没法产生像 Transformer 一样的 token 的概念,也就是说如果我们只使用卷积网络,那么 image token 概念的建立就不那么直观。所以,像 Transformer 那样在 token 的基础上进行自监督学习就不太适用,这是第一个难点。
  2. 语言和图片 (视频) 的信息密度不同:语言是人类造就的信号,它 highly semantic , information-dense。而图片 (视频) 是自然产生的信号,它 heavy spatial redundancy。即挡住图片的一部分 patches,可以很容易地通过看它周围的 patches 而想象出它的样子来。所以,语言和图像,一个信息密度高,一个信息密度低,这是第二个难点。解决的办法是什么呢?作者提出了一个简单的策略:即挡住图片的 patches 的比例高一些。比如之前你挡住一张图片的 30% 的 patches,能够轻松通过周围的 patches 预测出来;那现在如果挡住图片的 90% 的 patches,还能够轻松通过周围的 patches 预测出来吗?
  3. AutoEncoder 里面的 Decoder 部分 (就是将映射得到的中间特征重建为 input 的模块) 在 CV 和 NLP 中充当的角色不同:在 CV 领域,Decoder 的作用是重建 image pixels,所以 Decoder 的输出语义级别很低。在 NLP 领域,Decoder 的作用是重建 sentence words ,所以 Decoder 的输出语义级别很丰富。

1.1 MAE Encoder

MAE Encoder 采用 ViT 架构,但只会作用于 unmasked images。和 ViT 思路一样,MAE Encoder 会先通过 Linear Projection 编码图片,再加上位置编码,随后送入一堆连续的 Transformer Block 里面。但是编码器只对整个图片 patches 集合的一个小子集 (例如25%)进行操作,而删除 masked patches(75%)。这里和 BERT 做法不一样,BERT 使用对于 mask 掉的部分使用特殊字符代替,而 MAE 不使用掩码标记。
在这里插入图片描述

代码实现:

Patch Embedding(timm库):它接受张量形状为 (batch_size, RGB_channels, height, width) 的图像。 通过执行线性投影为每个Patch获得嵌入, 这是通过 2D 卷积层来完成。 然后张量在最后一个维度被展平(压扁),变成 (batch_size, encoder_embed_dim, num_visible_patches),并 转置为形状(batch_size, num_visible_patches, encoder_embed_dim)的张量。

class PatchEmbed(nn.Module): 
    """ Image to Patch Embedding """ 
    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768): 
        super().__init__() 
        self.img_size = img_size 
        self.patch_size = patch_size 
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 
 
    def forward(self, x, **kwargs): 
        B, C, H, W = x.shape 
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
        x = self.proj(x).flatten(2).transpose(1, 2) 
        return x

Position Embedding:位置编码添加了有关每个Patch位置的信息。 使用“sine-cosine”版本而不是可学习的位置嵌入。 下面的这个实现是一维版本。

def get_sinusoid_encoding_table(n_position, d_hid):  
 
    def get_position_angle_vec(position):  
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]  
 
    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])  
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i  
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1  
    return torch.FloatTensor(sinusoid_table).unsqueeze(0) 

ViT Block:与 Transformer 类似,每个块由归一化层Norm多头注意力模块Attention前馈层FFN组成。 中间输出形状是(batch_size, num_visible_patches, encoder_embed_dim)
多头注意力模块的代码如下:

class Attention(nn.Module): 
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., attn_head_dim=None): 
        super().__init__() 
        self.num_heads = num_heads 
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads 
        all_head_dim = head_dim * self.num_heads 
        self.scale = qk_scale or head_dim ** -0.5 
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 
        self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None 
        self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else None 
        self.attn_drop = nn.Dropout(attn_drop) 
        self.proj = nn.Linear(all_head_dim, dim) 
        self.proj_drop = nn.Dropout(proj_drop) 
 
    def forward(self, x): 
        B, N, C = x.shape 
        qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) if self.q_bias is not None else None 
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple) 
        q = q * self.scale 
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1) 
        attn = self.attn_drop(attn) 
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 
        x = self.proj_drop(self.proj(x)) 
        return x

ViT Block (from timm) 代码:

class Block(nn.Module): 
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None): 
        super().__init__() 
        self.norm1 = norm_layer(dim) 
        self.attn = Attention( 
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 
        self.norm2 = norm_layer(dim) 
        self.mlp = nn.Sequential( 
            nn.Linear(dim, int(dim * mlp_ratio)), 
            act_layer(), 
            nn.Linear(int(dim * mlp_ratio), dim), 
            nn.Dropout(attn_drop) 
        ) 
 
    def forward(self, x): 
        x = x + self.attn(self.norm1(x)) 
        x = x + self.mlp(self.norm2(x)) 
        return x

总Encoder实现:这部分仅用于下游任务的微调。 论文的模型遵循 ViT 架构,该架构具有用于分类的类令牌(patch)。 因此,他们添加了一个虚拟CLS令牌,但是论文中也说到他们的方法在没有它的情况下也可以运行良好,因为对其他tokens执行了平均池化操作。 在这里也包含了实现的平均池化版本。 之后,添加一个线性层作为分类器。 最终的张量形状是 (batch_size, num_classes)。其实PatchEmbed和Block可以从timm.models.vision_transformer导入

class Encoder(nn.Module) 
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=0, **block_kwargs): 
        super().__init__() 
        self.num_classes = num_classes 
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models 
 
        # Patch embedding 
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 
        num_patches = self.patch_embed.num_patches 
 
        # Positional encoding 
        self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 
 
        # Transformer blocks 
        self.blocks = nn.ModuleList([Block(**block_kwargs) for i in range(depth)])  # various arguments are not shown here for brevity purposes 
        self.norm =  norm_layer(embed_dim) 
 
        # Classifier (for fine-tuning only) 
        self.fc_norm = norm_layer(embed_dim) 
        self.head = nn.Linear(embed_dim, num_classes) 
 
    def forward(self, x, mask): 
        x = self.patch_embed(x) 
        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() 
        B, _, C = x.shape 
        if mask is not None:  # for pretraining only 
            x = x[~mask].reshape(B, -1, C) # ~mask means visible 
        for blk in self.blocks: 
            x = blk(x) 
        x = self.norm(x) 
        if self.num_classes > 0:  # for fine-tuning only 
            x = self.fc_norm(x.mean(1))  # average pooling 
            x = self.head(x) 
        return x

1.2 MAE Decoder

MAE Decoder 采用 Transformer 架构,输入整个图片 patches 集合,不光是 unmasked tokens (图中蓝色色块),还有被 mask 掉的部分 (图中灰色色块)。每个 mask tokens 都是一个共享的、learnable的embedding token,它指示了这里有一个待预测的 tokens。作者还将 PE 添加到这个完整 image patch 集合中的所有 tokens 中,位置编码表示每个 patches 在图像中的位置的信息。

MAE Decoder 仅用于预训练期间执行图像重建任务。因为自监督学习的特点就是只用最后预训练好的 Encoder 完成分类任务。因此,可以灵活设计与编码器设计无关的解码器结构。作者用比编码器更窄更浅的很小的解码器做实验。 在这种非对称的设计下,tokens 就可以由轻量级解码器处理,这大大缩短了预训练的时间。
在这里插入图片描述

Decoder代码实现:解码器由一系列transformer 块组成。 在解码器的末端,有一个由norm层前馈层组成的分类器。 输入张量的形状为 (batch_size, num_patches,decoder_embed_dim) 而最终输出张量的形状为 (batch_size, num_patches, 3 * patch_size ** 2)

class Decoder(nn.Module): 
    def __init__(self, patch_size=16, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=768, **block_kwargs): 
        super().__init__() 
        self.num_classes = num_classes 
        assert num_classes == 3 * patch_size ** 2 
        self.num_features = self.embed_dim = embed_dim 
        self.patch_size = patch_size 
        self.blocks = nn.ModuleList([Block(**block_kwargs) for i in range(depth)])  # various arguments are not shown here for brevity purposes 
        self.norm =  norm_layer(embed_dim) 
        self.head = nn.Linear(embed_dim, num_classes) 
 
    def forward(self, x, return_token_num): 
        for blk in self.blocks: 
            x = blk(x) 
        if return_token_num > 0: 
            x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels 
        else: 
            x = self.head(self.norm(x)) 
        return x

总MAE模型结构代码:

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block
# from util.pos_embed import get_2d_sincos_pos_embed

class MaskedAutoEncoderViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
        embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=Flase
    ):
    super().__init__()

    # Encoder的embed_dim 和 Decoder的embed_dim大小不同

    #---Encoder------------
    # patchify: [3,224,224]->[16*16,3,14,14]->[14*14,16*16*3]->[seq_len=196,embed_dim=1024]
    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
    num_patches = self.patch_embed.num_patches
    
    self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))  # learnable
    # unlearnable(patch token + cls token)
    self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+1,embed_dim), requires_grad=None, norm_layer=norm_layer)

    self.blocks = nn.ModuleList([
        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
        for i in range(depth)
    ])
    self.norm = norm_layer(embed_dim)
    #----------------------

    #---Decoder------------
    self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bais=True)

    self.mask_token = nn.Parameter(torch.zeros(1,1,decoder_embed_dim)) # learnable
    self.decoder_pos_embed = nn.Parameter(torch.zeros(1,num_patches+1,decoder_embed_dim), requires_grad=None, norm_layer=norm_layer)

    self.decoder_blocks = nn.ModuleList([
        Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None,norm_layer=norm_layer)
        for i in range(decoder_depth)
    ])
    self.decoder_norm = norm_layer(decoder_embed_dim)
    self.decoder_pred  =nn.Linear(decoder_embed_dim, patch_size**2*in_chans, bais=True)
    #----------------------

    self.norm_pix_loss = norm_pix_loss
    self.initialize_weights()

    def initialize_weights(self):
        # PE替换为常量:initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        N, L, D  = x.shape()  # batch, seq_len, embed_dim

        len_keep = int(L * (1 - mask_ratio)) # keep 前25% * seq_len的tokens

        # random_shuffle by sorting rand_noise: sort noise for each sample
        noise = torch.rand(N, L, device=x.device)
        # argsort return sorted_indexes
        ids_shuffle = torch.argsort(noise, dim=1) # dim=1 is ascend sort in seq_len dim: small is keep, large is masked
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # for unshuffle
        """ example:
            noise = tensor([[0.2275, 0.5513, 0.4059, 0.2158, 0.9633]])
            ids_shuffle = tensor([[3, 0, 2, 1, 4]])
            ids_restore = tensor([[1, 3, 2, 0, 4]])
        """

        # keep the first subset(unmasked 25%)
        ids_keep = ids_shuffle[:,:len_keep]
        # sort x based on ids_keep in seq_len dim (dim=1)
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1,1,D))

        # generate the binary mask: 0 is keep, 1 is mask
        mask = torch.ones([N,L], device=x.device)  # all 1
        mask[:,:len_keep] = 0
        # sort mask to get the binary mask for x in seq_len dim (dim=1)
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches tokens
        x = self.patch_embed(x)
        # add pos embed w/o cls token in seq_len dim
        x = x + self.pos_embed[:, 1:, :]

        # masking patch tokens: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:,:1,:]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # expand for batch_size
        x = torch.cat((cls_tokens, x), dim=1)

        # apply cls token
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x, mask, ids_restore
    
    def forward_decoder(self, x, ids_restore):
        # embed tokens(embed_dim: 1024->512)
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        # [1,1,embed_dim] -> [batch_size, (mask+unmasked)_len - (unmasked_len - 1), embed_dim], (-1) is because x include cls token, but ids_restore not include cls token.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        # [batch_size, (mask+unmask)_len, embed_dim]
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        # [batch_size, (mask+unmask+cls)_len, embed_dim]
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection: [batch_size, (mask+unmask+cls)_len, embed_dim=512] -> [batch_size, (mask+unmask+cls)_len, embed_dim=16**2*3=768]
        x = self.decoder_pred(x)

        # remove cls token: [batch_size, (mask+unmask)_len=14*14=196, embed_dim=16*16*3=768]
        x = x[:, 1:, :]

        return x

def forward_loss(self, imgs, pred, mask):
        """
            Normlized Masked MSE loss
        imgs: [N=B, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        # [B, 3, H=224, W=224] -> [B, 3, num_patches=14*14, h*w*3=16*16*3]
        target = self.patchify(imgs)
        # normlize target
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        # masked loss: 0/1_mask * MSE loss, 1 is masked pixel, 0 is unmasked pixel.
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed unmasked patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

1.3 自监督学习目标函数 Reconstruction Target

Decoder 的最后一层Head是一个 Linear Projection 层,其输出的 channel 数等于图像的像素 (pixel) 数。所以 Decoder 的输出会进一步 reshape 成图像的形状。损失函数就是 MSE Loss,即直接让 reconstructed image 和 input image 的距离越接近越好,但注意我们只对Masked token计算loss(Mask Loss)。

作者还尝试了另外一种损失函数,就是先计算出每个 patch 的像素值的 mean 和 deviation,并使用它们去归一化这个 patch 的每个像素值。最后再使用归一化的像素值进行 MSE Loss 计算。但是发现这样做的效果比直接 MSE Loss 好。(归一化的Mask Loss

1.4 具体实现

MAE 的具体实现方法是:

  • 首先通过 Linear Projection 和位置编码得到 image tokens。
  • 随机 shuffle 这些 tokens,按照 masking ratio 扔掉最后的一部分。
  • 把 unmasked patches 输出到 Encoder 中,得到这些 tokens 的表征。
  • 把 Encoder 的输出,结合 masked tokens (可学习的向量),执行 unshuffle操作恢复顺序,再一起输入到 Decoder 中。
  • shuffle 和 unshuffle 操作的时间开销可忽略不计。
class MAE(nn.Module): 
    def __init__(self, ...):  # various arguments are not shown here for brevity purposes 
        super().__init__() 
        self.encoder = Encoder(img_size, patch_size, in_chans, embed_dim, norm_layer, num_classes=0, **block_kwargs) 
        self.decoder = Decoder(patch_size, embed_dim, norm_layer, num_classes, **block_kwargs) 
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False) 
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 
        self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim) 
 
    def forward(self, x, mask): 
        x_vis = self.encoder(x, mask) 
        x_vis = self.encoder_to_decoder(x_vis) 
        B, N, C = x_vis.shape 
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() 
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C) 
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C) 
        x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) 
        x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16] 
        return x

MAE的优势

(1)Scalable:encoder只操作可见patches,把mask tokens给本身参数就不多的decoder去运算,大大降低了计算量,尤其当mask的比例很高的时候,大大减少了预训练时间,让MAE可以很轻松的scale到更大的模型上(enabling us to easily scale MAE to large models),并且通过实验发现随着模型增大,效果越来越好

(2)高容量且泛华性能好(very high-capacity models that generalize well):使用MAE预训练方法,可以训练很大的model,比如ViT-Large/Huge,当把预训练好的ViT-Huge迁移到下游任务时,模型表现非常好,甚至超过了使用监督预训练的相同模型(achieves better results than its supervised pre-training counterparts),这说明MAE预训练学习到的表示可以很好的泛化到下游任务(these pre-trained representations generalize well to various downstream task)

2. 实验分析

在ImageNet-1K上自监督预训练,使用标准ViT结构,预训练后,使用encoder进行微调和linear probing,因为是用于图像分类,所以类似于ViT,在输入加一个class token(an auxiliary dummy token),实验结果表明使用average pooling可以达到相同的效果

不同ViT-B/L/H的区别在于Block的深度、Token Embedding的大小(token dim)、MLP的size、Attention Head。
在这里插入图片描述

(0)消融实验
在这里插入图片描述

(1)预训练阶段

没有使用color jittering(数据增强的方式之一)、drop path(dropout的变体)、gradient clip(设置阈值预防梯度爆炸/消失)。是ViT官方代码相同,使用xavier uniform初始化所有Transformer blocks。使用linear learning rate scaling rule

在这里插入图片描述

(2)端到端微调

使用layer-wise learning rate decay

在这里插入图片描述

(3)linear probing

训练设置参考MoCov3,linear probing和端到端微调有很大不同,regularization对linear probing来说可能会损失模型性能,因此和MoCov3中一样,舍弃了一些regularization strategies
在这里插入图片描述

(4)部分微调(partial fine-tune):

linear probing缺少非线性建模能力(it misses the opportunity of pursuing strong but non-linear features—which is indeed a strength of deep learning),partial fine-tune 只微调encoder最后个layers,其超参数等设置和微调时相同的(table 9),除了调整了fine-tunning epochs

四个阶段均计算top-1 accuracy(224x224),使用ViT-Large作为baseline,进行ablation study。对比ViT-Large 从头训练(200 epochs)和微调(50 epochs)两种方式,可以发现train from scratch效果并不如微调
在这里插入图片描述

用 MAE 做 pre-training 只需 ImageNet-1k 就能达到 87.8% 的 Top-1 准确度,超过了所有在 ImageNet-21k pre-training 的 ViT 变体模型。而从方法上看,MAE 选择直接重建原图的元素,而且证明了其可行性,改变了人们的认知,又几乎可以覆盖 CV 里所有的识别类任务,看起来像是开启了一个新的方向。直接重建原图的元素是非常重要的,因为通过这个形式,作者就用最最直观的方式完成了 MIM 任务,使得 MIM的潜力逐步被证实。从 MLM 到 MIM 的过渡已被证明,由此观之比肩 GPT3 的 CV 预训练大模型已不远矣。

### MAE(Masked Autoencoder)预训练概述 #### 背景介绍 MAE 是一种用于无监督学习的自编码器框架,特别适用于大规模视觉任务的学习。它通过遮蔽部分输入图像并尝试重建这些缺失的部分来实现特征提取和表示学习[^1]。 #### 关键概念解释 Fine-tuning 和 Linear Probing 是评估模型性能的重要手段。Fine-tuning 表示在整个网络上进行参数调整以适应特定下游任务;而 Linear Probing 则是在冻结预训练模型的基础上仅训练一个线性分类层来进行测试[^2]。 #### 技术细节分析 具体来说,PyTorch 中实现了 MAE 的核心流程如下:首先为输入图片创建 token embedding 并附加位置编码(Positional Encoding),可以采用固定的正弦/余弦形式或者是可学习的位置嵌入方式。接着打乱这些 tokens 的顺序,并选取其中前 25% 进行处理,剩余未选中的则填充由共享且可学习的 mask embeddings 构成的内容作为占位符。随后将上述准备好的数据送入 Encoder 获取编码后的 patches 特征向量(embeddings) 。之后把这些结果与之前保留下来的那批未经修改的数据重新组合起来恢复原始排列次序(unshuffling process),最终传递给 Decoder 完成整个重构过程[^3]。 #### 性能对比展示 当应用于 COCO 数据集的目标检测以及实例分割任务时,基于 ViT 骨干网路并通过 FPN 结构增强后发现,在不同规模设置下均优于传统监督式预先训练方案。特别是针对较大尺寸版本(ViT-Large),利用 MAE 方法所获得的表现提升了整整四个百分点(从原来的49.3提升到了现在的53.3)[^4]。 ```python import torch from torchvision import transforms, datasets from mae_model import MAEModel # 自定义模块导入 def main(): transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), ]) dataset = datasets.ImageFolder(root='path_to_dataset', transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=8) model = MAEModel() for images, _ in dataloader: outputs = model(images) break if __name__ == '__main__': main() ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuezero_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值