【深度学习】 MAE|心中无码,便是高清

在之前一篇推文一文串起从NLP到CV 预训练技术和范式演进中,由于篇幅有限,仅仅介绍了深度学习中的预训练技术发展,基本思路是顺着CV和NLP双线的预训练技术发展演进。

这里正式开启一个顺着这篇推文的倒叙精读系列。

Masked Autoencoders Are Scalable Vision Learners

正好mae的官方pytorch在两周前开源了

https://github.com/facebookresearch/mae

我们用倒叙的方式,从MAE往后看。开篇用一个非常夸张的实验效果demo图。这个效果实在是太夸张了,人都脑补不出来这样的马赛克程度。

4e67c19cf720cab1ae1a8a1976198030.png

摘要

MAE的方法非常简单,随机MASK住图片里的一些块,然后再去重构这些被MASK住的像素。这整个思想也来自 BERT 的带掩码的语言模型,但不一样的是这一个词(patches) 它就是一个 image 的一个块,然后它预测的是你这块里面的所有的像素。

全文有两个重要的创新点:跑得快+学得难

跑得快:非对称的自编码器架构(autoencoder),其编码器仅作用在可见的这些patch里面, 如果一个 patch 被它丢掉了,那么编码器就不会对它进行编码。这样图像encoder端的工作量就减少了,好处就是可以跑得很快。进一步地,解码器是一个比较轻量的解码器。一层transformer就够用。

学得难:预训练任务能够重构原始的像素级图片。并且,可以搞定75%的这些块全部遮住下的图像复原。这个事情是一个非平凡的,而且有意义的自监督的任务。如果你就简单遮住几块的话,那么就插一下值,你就可以出来了,这样整个模型可能学不到特别有意思的东西。但是你要是遮住高达75%的部分,苦一苦你的模型,说不定他会学到一些更好的一些表征

然后把这两个放在一起,跑得快+学得难,我们就可以让他做一些超越自己当前模型水平的水平的事情(老PUA了)

结果:用更小的数据来自监督预训练,超越了更多数据监督训练的ViT模型。他用来自于VIT这个论文的不加任何技巧的ViT-Huge的模型backbobe结构,

加上他的预训练方法,能够得到 87.8% 的ACC表现。

最后,强调一下迁移学习也很好。当然,预训练模型不迁移学习,那岂不预训练了一个寂寞。

结构

论文一般有两个图最重要,一个是第一页右上角的小图,第二个是第三页横跨双栏的大图。

4cb3ed6026c75eab7ce2749ed62fe7f4.png

这是 MAE体的架构图,预训练阶段一共分为四个部分,MASK,encoder,decoder。

MASK

可以看到一张图片进来,首先把你切块切成一个一个的小块,按格子切下来。

其中要被MASK住的这一块就是涂成一个灰色,然后没有MASK住的地方直接拎出来,这个地方75%的地方被MASK住了。注意是随机采样,而不是什么中心采样,网格的采样,局部采样等方式,s这部分在实验里对比过。这里比较符合认知的解释是,可以防止引入类似中心归纳偏好等特定bias,随机是最公平的。

encoder

前面拎起来的像素块即unmask部分,放进一个 encoder 的里面,这里采用了ViT论文中的transformer backbone,得到每一个块它对应的这一些特征。

在这个地方它要把它拉长,把这些被MASK那些块,重新放回到原来的位置,把它拉成一条向量。在预训练的时候,MASK住的东西,其实啥也没有了,作者给了他一个可以学习的共享隐向量+Position  embedding(!!!!这个地方比较难trick,推荐看一下代码实现)没有MASK住的,就是填上那 ViT 它出来的这些特征。组成一个长的隐层向量,输到一个解码器里面。

decoder

解码器会去尝试把里面的像素信息全部重构回来,得到最后的 target(目标的像素值)。要注意的是,解码的过程是没有加速度的,但是解码的模型一般都不大。我们知道编码的Transformer 这些模型计算量都特别大,如果有个几倍的加速,其实也是非常重要的一个事情。

下游任务

如果你想用这个模型来做一个下游任务呢,你就只需要它的编码器就行了,解码器是不需要的,你的图片进来你不需要对它做掩码

你直接切成这些格子块。然后过encoder它就会得到你所有那些块的一个特征的表示,这个就是你的图片的语义表征(representation)

实现细节

encoder

1.patch,图像切块, 图像在tensor中的表示为 (B,C,H,W) reshape 成 (B,N,PxPxC),其中B是Batch大小,N和P分别为 patch 数量 和 patch 大小。

N = H*W/P/P。

2.patch embedding, 1中的图片切块的嵌入表征,他是连续值经过一层全连接得到固定维度大小的值(dim),注意文本是one-hot形式,或者look up table的形式。

从1中的 (B,N,PxPxC) -> (B,N,dim)

3.position embedding,patch编码对应的embeding,这个和NLP中的词表查到的embedding是一样的。

4.部分编码,预训练阶段的Encoder从实现角度再复述一遍:图像切块-没有MASK的部分走patch embedding+position embedding

def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

decoder

1.mask部分的对应的隐向量并不来自于encoder的推断,而是直接在这里进行凭空初始化的共享token向量+position embedding。

2.decoder不需要用encoder那么重的模型。你可以理解为Bert的decoder就是个MLP,这里可以用一个特别简单的一层transformer。虽然decoder在数量补齐了复杂度,因为模型简单,压力并不太大。

def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

loss

1.仅用MSE算mask path的像素差值。只算mask patch是因为实验结论,否则有大约0.5%的ACC下降。

2.归一化的像素值作为target比较好。实验结果

总结

预训练阶段:

1.图片切patch

2.patch做embedding (projection方式)

3.加上position embedding (lookup table方式)

4.mask打码(75%)

5.无码部分进encoder

6.有码部分做好可训练的共享语义向量+position embedding

7.按patch的原始顺序拼好mask和unmask的对应语义向量,送decoder

8.取decoder出来的,mask部分对应的像素值算mse loss。

实验部分

1.mask比例

0289348def50855989dc4ec51263dd75.png

少了多了都不好。所以说恰当的压力才是前进的动力。直观理解就是太简单了学不到东西,太难了也学不会。

2.采样策略

随机采样效果最好,其他的方式多多少少泛化能力都差一点。

d626d98a2c524ca418396d0ec471f579.png

block的任务更难,扣掉一大块比例太大也学不好,对于模型来说太难了。扣掉50%差不多了,但是效果比随机还差一点。和上面个实验一样,刚刚好比较好。

3.decoder设计

58afac89dd7691b0fb55f2a839d2d0a4.png

用深层和更大decoder不太好。其实也可以理解,encoder出来的隐向量的信息已经够复杂了。第二点是苦一苦encoder,这样在下游任务他发挥的更好一点。要是用复杂的encoder,信息和建模能力,都隐藏在decoder恐怕就没有这么好的效果了。

4.重建目标

作者和 BEiT 那种预测token的方式 以及 PCA 的方式。patch 做 PCA 并预测最大的因子,进行了比较。有无归一化也进行了比较。

ccf3db20fc1babd146635cd4413e40e3.png

5.数据增强

保持图片局部完整信息的随机缩放,比其他引入噪声的方式都要好。

acfda4be04845f1e39e2d6f3949ce402.png

 
 
 
 
 
 
 
 
 
 
往期精彩回顾




适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载中国大学慕课《机器学习》(黄海广主讲)机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码

3d76c6d5f3a4a976ab7ca12c63f85b26.png

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值