Respect Kaiming | Hand-rolling MAE and plugging it seamlessly into the mm-series frameworks (classification / detection / segmentation)

Source: Zhihu — tik boa

Link: https://zhuanlan.zhihu.com/p/444310267


01

Overview

If you are still grinding away anywhere in the CV world, you have surely heard of Kaiming He's MAE. Before it, self-supervised learning was dominated by contrastive methods such as the MoCo series, SimCLR, BYOL, SimSiam and SimCSE. More recent work instead uses reconstruction as the pretext task: BEiT reconstructs discrete visual tokens as its proxy task, while MAE shows that reconstructing raw pixels works perfectly well, even at a 75% mask ratio.

Moreover, MAE-style self-supervised pretraining transfers very well to downstream tasks. As the table below shows, on the COCO dataset a ViT pretrained with MAE improves box/mask AP substantially over supervised pretraining, MoCo, BEiT and others.

[Figure: COCO object detection and instance segmentation results]

Before this, training a downstream network on a custom dataset almost always started from ImageNet-pretrained weights, which raises two problems:

  • if the custom dataset resembles ImageNet, fine; but medical or industrial images belong to a different domain, so there is a domain gap;

  • labeled data for a custom dataset is slow and costly to obtain, while large amounts of unlabeled data go unused;

This article therefore walks through building an MAE by hand, getting the details right so you know not only what it does but why. On top of that, we port it into the mm-series frameworks (mmdet/mmcls/mmseg) to quickly build a strong backbone of your own. Without further ado, let's enter the world of MAE.


02

MAE (Masked Autoencoders)

The MAE paper is extremely concise; there is not a single equation in it. Even a newcomer to CV will roughly understand it after one read. The MAE encoder-decoder structure is shown below: the encoder takes only the unmasked patches as input, and the decoder takes the encoder output together with the embeddings of the masked patches.

[Figure: MAE encoder-decoder architecture]

The paper boils down to five parts:

  • building the encoder;

  • building the decoder;

  • building the MAE itself;

  • the pretraining pipeline;

  • the finetuning pipeline;

The following sections unpack these five parts one by one.

encoder

The encoder is a ViT with the classification head removed; the remaining structure consists of:

  • Patch embedding

  • Transformer encoder

    • Attention

    • MLP


patch embedding

This part cuts an image into patches, projects each patch down to an embedding vector, and flattens them into a sequence, as shown below.

[Figure: patch embedding]

The code for this part (with the imports shared by all the code snippets in this article):

# shared imports for the code snippets in this article
import math
import sys
from typing import Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import to_2tuple, trunc_normal_


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches


        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)


    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

attention

The attention mechanism is essentially an addressing process: given a task-dependent query vector $q$, we compute an attention distribution over the keys and use it to weight the values, which yields the attention value.

The computation can be broken into three steps:

① Feed in the information: the input is represented as $X = [x_1, \dots, x_N]$ together with the query $q$ (for self-attention, $Q$, $K$ and $V$ are all linear projections of $X$).

② Compute the attention distribution $\alpha$: score the query against each key with a scoring function $s(x_i, q)$ (e.g. a dot product) and normalize the scores with softmax,

$\alpha_i = \mathrm{softmax}\big(s(x_i, q)\big) = \dfrac{\exp\big(s(x_i, q)\big)}{\sum_{j=1}^{N} \exp\big(s(x_j, q)\big)}$

We call $\alpha_i$ the attention distribution and $s(x_i, q)$ the attention scoring function. Common choices are:

Additive: $s(x_i, q) = v^{\top} \tanh(W x_i + U q)$

Dot-product: $s(x_i, q) = x_i^{\top} q$

Scaled dot-product: $s(x_i, q) = \dfrac{x_i^{\top} q}{\sqrt{d}}$

Bilinear: $s(x_i, q) = x_i^{\top} W q$

③ Weighted average: the attention distribution $\alpha_i$ describes how much the $i$-th piece of information matters under the query $q$; the output is the weighted sum

$\mathrm{att}(X, q) = \sum_{i=1}^{N} \alpha_i x_i$

[Figure: attention structure]
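
To make the three steps above concrete, here is a minimal, self-contained sketch of scaled dot-product self-attention in PyTorch; this is not the article's Attention class, just the bare computation, and the shapes are chosen purely for illustration:

import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [batch, seq_len, dim]
    d = q.size(-1)
    # step 2: score every query against every key and normalize with softmax
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # [batch, seq_len, seq_len]
    attn = scores.softmax(dim=-1)                 # attention distribution alpha
    # step 3: weighted average of the values
    return attn @ v                               # [batch, seq_len, dim]

x = torch.randn(2, 196, 768)                      # step 1: e.g. 196 patch tokens of dim 768
out = scaled_dot_product_attention(x, x, x)       # self-attention: Q = K = V come from x
print(out.shape)                                  # torch.Size([2, 196, 768])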

Multi-head attention simply runs scaled dot-product attention H times in parallel (each head with its own projections) and concatenates the outputs.

The multi-head attention formula is:

$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_H)\, W^{O}$, where $\mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V})$

The code:

class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5


        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None


        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)


    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)


        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))




        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)


        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

MLP

The MLP is simple: two stacked linear layers with an activation in between. The code:

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)


    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out, following the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x
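
The encoder and decoder below stack these pieces through a Block module that the excerpted code never shows. For reference, here is a minimal sketch of such a pre-norm transformer block (residual connections, optional DropPath and layer scale via init_values), assuming the Attention and Mlp classes above and timm's DropPath; the reference implementation may differ in detail:

from timm.models.layers import DropPath

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., init_values=None,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                              qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if init_values is not None and init_values > 0:
            # layer scale: learnable per-channel scaling of each residual branch
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x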

assembling the encoder

The positional embedding pos_embed can take two forms:

① learnable: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
② fixed (sine-cosine): self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

The fixed version is generated as follows:

def get_sinusoid_encoding_table(n_position, d_hid): 
    ''' Sinusoid position encoding table ''' 
    # TODO: make it with torch instead of numpy 
    def get_position_angle_vec(position): 
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 


    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 


    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

Finally, the MAE encoder code:

class PretrainVisionTransformerEncoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_learnable_pos_emb=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models


        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches


        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        else:
            # sine-cosine positional embeddings 
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)


        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm =  norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()


        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)


        # trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)


    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


    def forward_features(self, x, mask):
        x = self.patch_embed(x)


        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()


        B, _, C = x.shape
        x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible


        for blk in self.blocks:
            x_vis = blk(x_vis)


        x_vis = self.norm(x_vis)
        return x_vis


    def forward(self, x, mask):
        x = self.forward_features(x, mask)
        x = self.head(x)
        return x

NOTE:

The line x_vis = x[~mask].reshape(B, -1, C) is what restricts the encoder to operating only on the unmasked (visible) patches; see the toy example below.
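
A tiny standalone illustration of this boolean-mask indexing (the shapes are made up for illustration; note that every sample in the batch must mask the same number of patches, otherwise the reshape would fail):

B, N, C = 2, 196, 768                       # batch, patches per image, embed dim
x = torch.randn(B, N, C)                    # patch embeddings + pos_embed

mask = torch.zeros(B, N, dtype=torch.bool)  # True = masked, False = visible
mask[:, :147] = True                        # mask 75% of the 196 patches (147 per image)

x_vis = x[~mask].reshape(B, -1, C)          # keep only the visible patches
print(x_vis.shape)                          # torch.Size([2, 49, 768])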

decoder

The decoder mirrors the encoder structurally. Two points are worth noting:

  • its input is the full set of patch tokens (encoded visible tokens plus mask tokens);

  • its output is the prediction for the masked patches only;

Both points are covered in more detail in the MAE section below.

class PretrainVisionTransformerDecoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_patches=196,
                 ):
        super().__init__()
        self.num_classes = num_classes
        assert num_classes == 3 * patch_size ** 2
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_size = patch_size


        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm =  norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()


        self.apply(self._init_weights)




    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


    def forward(self, x, return_token_num):
        for blk in self.blocks:
            x = blk(x)


        x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels
        return x

MAE

When wiring the MAE encoder to the decoder, note the following:

  • the encoder (ViT-Base) has embedding dim 768 while the decoder is narrower (384 in the article's description; the code below defaults to 512), so features pass through the linear encoder_to_decoder projection and the positional embedding is regenerated at the decoder width;

  • the decoder's output dimension is 768 = 16 x 16 x 3, i.e. the RGB pixel values of one patch;

  • the decoder only returns predictions for the masked patches, which is what x = self.head(self.norm(x[:, -return_token_num:])) does: visible and mask tokens are concatenated with the visible ones first, so the last return_token_num positions are the masked ones.

The overall process is illustrated below; after this figure nothing should remain unclear:

[Figure: overall MAE forward process]

The MAE code:

class PretrainVisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self,
                 img_size=224, 
                 patch_size=16, 
                 encoder_in_chans=3, 
                 encoder_num_classes=0, 
                 encoder_embed_dim=768, 
                 encoder_depth=12,
                 encoder_num_heads=12, 
                 decoder_num_classes=768, 
                 decoder_embed_dim=512, 
                 decoder_depth=8,
                 decoder_num_heads=8, 
                 mlp_ratio=4., 
                 qkv_bias=False, 
                 qk_scale=None, 
                 drop_rate=0., 
                 attn_drop_rate=0.,
                 drop_path_rate=0., 
                 norm_layer=nn.LayerNorm, 
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 num_classes=0, # avoid the error from create_fn in timm
                 in_chans=0, # avoid the error from create_fn in timm
                 ):
        super().__init__()
        self.encoder = PretrainVisionTransformerEncoder(
            img_size=img_size, 
            patch_size=patch_size, 
            in_chans=encoder_in_chans, 
            num_classes=encoder_num_classes, 
            embed_dim=encoder_embed_dim, 
            depth=encoder_depth,
            num_heads=encoder_num_heads, 
            mlp_ratio=mlp_ratio, 
            qkv_bias=qkv_bias, 
            qk_scale=qk_scale, 
            drop_rate=drop_rate, 
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate, 
            norm_layer=norm_layer, 
            init_values=init_values,
            use_learnable_pos_emb=use_learnable_pos_emb)


        self.decoder = PretrainVisionTransformerDecoder(
            patch_size=patch_size, 
            num_patches=self.encoder.patch_embed.num_patches,
            num_classes=decoder_num_classes, 
            embed_dim=decoder_embed_dim, 
            depth=decoder_depth,
            num_heads=decoder_num_heads, 
            mlp_ratio=mlp_ratio, 
            qkv_bias=qkv_bias, 
            qk_scale=qk_scale, 
            drop_rate=drop_rate, 
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate, 
            norm_layer=norm_layer, 
            init_values=init_values)


        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)


        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))


        self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)


        trunc_normal_(self.mask_token, std=.02)




    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


    def forward(self, x, mask):


        x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
        x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]


        B, N, C = x_vis.shape


        # we don't unshuffle the visible tokens back to their original order,
        # but shuffle the pos embedding accordingly.
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
        # notice: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
        x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]


        return x
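
As a quick sanity check of the shapes, here is a hypothetical forward pass using the classes above (224x224 input, 16x16 patches, 75% masking; the mask is fixed rather than random just to keep the example deterministic):

model = PretrainVisionTransformer(
    img_size=224, patch_size=16,
    encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
    decoder_num_classes=768, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=8)

imgs = torch.randn(2, 3, 224, 224)

num_patches = (224 // 16) ** 2                    # 196
num_mask = int(0.75 * num_patches)                # 147
mask = torch.zeros(2, num_patches, dtype=torch.bool)
mask[:, -num_mask:] = True                        # mask the last 147 patches of every image

pred = model(imgs, mask)                          # the encoder only sees the 49 visible patches
print(pred.shape)                                 # torch.Size([2, 147, 768]) -> pixels of the masked patches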

pretraining pipeline

One detail in the pretraining loop deserves attention:

the pixel values of each patch are normalized per patch (zero mean, unit variance), which makes the reconstruction target easier to learn:

images_squeeze = rearrange(unnorm_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=patch_size, p2=patch_size)
images_norm = (images_squeeze - images_squeeze.mean(dim=-2, keepdim=True)
               ) / (images_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
images_patch = rearrange(images_norm, 'b n p c -> b n (p c)')
B, _, C = images_patch.shape
labels = images_patch[bool_masked_pos].reshape(B, -1, C)

Here bool_masked_pos is the boolean index of the masked image patches, produced by the generator below:

random mask generator

class RandomMaskingGenerator:
    def __init__(self, input_size, mask_ratio):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2


        self.height, self.width = input_size


        self.num_patches = self.height * self.width
        self.num_mask = int(mask_ratio * self.num_patches)


    def __repr__(self):
        repr_str = "Maks: total patches {}, mask patches {}".format(
            self.num_patches, self.num_mask
        )
        return repr_str


    def __call__(self):
        mask = np.hstack([
            np.zeros(self.num_patches - self.num_mask),
            np.ones(self.num_mask),
        ])
        np.random.shuffle(mask)
        return mask # [196]
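
For example, with 224x224 images and 16x16 patches (a 14x14 grid) at the 75% mask ratio used above:

gen = RandomMaskingGenerator(input_size=14, mask_ratio=0.75)
print(gen)                       # Mask: total patches 196, mask patches 147

mask = gen()                     # numpy array of shape [196]; 1 = masked, 0 = visible
bool_masked_pos = torch.from_numpy(mask).unsqueeze(0).to(torch.bool)   # [1, 196], as fed to the model
print(bool_masked_pos.sum())     # tensor(147)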

The training code:

def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16, 
                    lr_scheduler=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None):
    model.train()
    loss_func = nn.MSELoss()
    # metric_logger / print_freq / header come from the reference training script's logging utils;
    # IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD are imported from timm.data.constants
    for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        it = start_steps + step
        if lr_schedule_values is not None or wd_schedule_values is not None:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]


        images, bool_masked_pos = batch
        images = images.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)


        # import pdb; pdb.set_trace()
        with torch.no_grad():
            # calculate the predict label
            mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None]
            std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None]
            unnorm_images = images * std + mean  # in [0, 1]


            images_squeeze = rearrange(unnorm_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=patch_size, p2=patch_size)
            images_norm = (images_squeeze - images_squeeze.mean(dim=-2, keepdim=True)
                ) / (images_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
            images_patch = rearrange(images_norm, 'b n p c -> b n (p c)')


            B, _, C = images_patch.shape
            labels = images_patch[bool_masked_pos].reshape(B, -1, C)


        with torch.cuda.amp.autocast():
            outputs = model(images, bool_masked_pos)
            loss = loss_func(input=outputs, target=labels)


        loss_value = loss.item()


        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)


        optimizer.zero_grad()
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                parameters=model.parameters(), create_graph=is_second_order)




        if lr_scheduler is not None:
            lr_scheduler.step_update(start_steps + step)

finetuning pipeline

During finetuning, a few things are worth noting:

  • layer-wise learning-rate decay (a sketch of wiring the scales into the optimizer follows after the scale list below):

class LayerDecayValueAssigner(object):
    def __init__(self, values):
        self.values = values


    def get_scale(self, layer_id):
        return self.values[layer_id]
assigner = LayerDecayValueAssigner(list(0.75 ** (12 + 1 - i) for i in range(12 + 2)))

With a 12-layer encoder (plus the patch embedding at the bottom and the head at the top), the per-layer learning-rate scale factors are:

[0.023757264018058777, 0.03167635202407837, 0.04223513603210449, 0.056313514709472656, 0.07508468627929688, 0.1001129150390625, 0.13348388671875, 0.177978515625, 0.2373046875, 0.31640625, 0.421875, 0.5625, 0.75, 1.0]

The motivation: lower layers capture generic low-level patterns while higher layers capture semantics, so shrinking the learning rate towards the bottom acts like gradually unfreezing the network from bottom to top.
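
A minimal sketch of how these scales could be attached to optimizer parameter groups, assuming a hypothetical helper get_layer_id that maps a parameter name to its layer id (the reference training scripts use a more complete version that also handles no-weight-decay parameters):

def get_layer_id(name, num_layers=12):
    # hypothetical mapping: embeddings -> 0, blocks.i -> i + 1, everything else (e.g. the head) -> num_layers + 1
    if name.startswith(('patch_embed', 'pos_embed', 'cls_token')):
        return 0
    if name.startswith('blocks.'):
        return int(name.split('.')[1]) + 1
    return num_layers + 1


def build_param_groups(model, assigner, base_lr=1e-3, weight_decay=0.05):
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        layer_id = get_layer_id(name)
        scale = assigner.get_scale(layer_id)
        if layer_id not in groups:
            groups[layer_id] = {'params': [], 'lr': base_lr * scale,
                                'lr_scale': scale, 'weight_decay': weight_decay}
        groups[layer_id]['params'].append(param)
    return list(groups.values())


# usage: optimizer = torch.optim.AdamW(build_param_groups(vit_base, assigner), lr=1e-3)
# the "lr_scale" key is what the training loop above multiplies into param_group["lr"]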

  • use broader data augmentation and regularization (see the timm sketch after this list):

    • augmentation: RandAug (9, 0.5)

    • label smoothing 0.1

    • mixup 0.8

    • cutmix 1.0

    • drop path 0.1 (B/L), 0.2 (H)
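
These recipes map fairly directly onto timm; a hedged sketch of one possible setup (the exact argument values in the official finetuning scripts may differ slightly, and drop path is simply the drop_path_rate argument when building the ViT):

from timm.data import create_transform, Mixup

# RandAugment with magnitude 9 / std 0.5, plus the usual finetuning crop and flip
train_transform = create_transform(
    input_size=224, is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',
    interpolation='bicubic')

# mixup 0.8 + cutmix 1.0 + label smoothing 0.1, applied batch-wise during training
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0,
    label_smoothing=0.1, num_classes=1000)

# inside the training loop:
# images, targets = mixup_fn(images, targets)   # targets become soft labels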


03

Results

Here is an example reconstruction result:

[Figure: MAE reconstruction result]

Seamlessly plugging it into the mm-series frameworks: to be continued~
