【学习日记week7】VQVAE(dVAE),DALL-E,BEiT:基于生成/跨模态生成的工作概述

写在前面(为“日记”加一点亲切感)

写下这段内容的时候笔者正在医院看病,上个特别好的周末时光被至今也不知道是什么的病和发烧给毁了。。这周末又是考s试答辩连着来,大四了人依旧很崩溃。。。

BEiT其实在上周就已经看了一半了,然后头痛欲裂在床上连躺三天,吃了好多药也没退烧,嗓子现在和刀割一样><。上周意识到BEiT和COTS用的都是VQVAE,这个我带来了很大的兴趣,这种方法不仅仅能够用于跨模态预训练(相比基于GAN和Diffusion,VQVAE成图速度应该会快一些),而且在我毕设的工作中,也可以成为缺失模态补全的一种手段,且VQVAE的可解释性很强,所以我决定中断BEiT和DALL-E的阅读,首先先认真学学VQ-VAE。

VQVAE

VQVAE和VAE的区别是借鉴了VQ的思想,消除了“后验坍塌”的问题。

contribution

  1. 介绍了VQ-VAE方法,能够不受到**后验坍塌(posterior collapse)**的影响
  2. VQ-VAE方法和基于连续的似然的方法能有一样好的效果
  3. 给定一个强大的先验,可以生成高质量的样本
  4. 可以在其他领域进行学习任务

VAE

AE是进行一个数据压缩,编码再解码的过程,而VAE就是通过编码器-解码器网络,通过对中间的分布进行参数化估计,然后进行对原本分布的学习。
基本的概念:

  • 对于VAE,编码器的部分是对于观测到的 x x x进行参数化的过程,从而得到随机的隐含变量 z z z,有后验分布 p ( z ∣ x ) p(z|x) p(zx)
  • 然后将 p ( z ) p(z) p(z)作为先验分布,通过 p ( x ∣ z ) p(x|z) p(xz) x x x进行建模

具体的VAE方法可以看VAE的介绍以及变分推理的内容。本文的VQVAE的灵感来自于VQ(向量数字化)的想法,让生成器不再在连续的高斯分布中生成,而是在离散分布中生成

离散隐变量

首先定义一个K个的离散隐空间(也可以理解为K个离散的隐向量,每个的维度是 D D D),即有latent embedding space: e ∈ R K × D e\in \mathbb R^{K\times D} eRK×D
在这里插入图片描述

简述一下这个框架:首先将输入x通过encoder生成embedding z e ( x ) z_e(x) ze(x),然后通过一个最近邻查找的方法,来生成潜在空间 e e e中的离散的潜在变量 z z z,公式化表述为(注意,因为是最近邻查找,所以这里的 q ( z ∣ x ) q(z|x) q(zx)其实是一个onehot的形式):
在这里插入图片描述
然后docoder的输入是最近邻的这个相关联的 e k e_k ek,有 k = arg min ⁡ j ∥ z e ( x ) − e j ∥ 2 k=\argmin_j\Vert z_e(x)-e_j\Vert_2 k=argminjze(x)ej2,而输入则有 z q ( x ) = e k z_q(x)=e_k zq(x)=ek

在VAE中,训练的损失有一个重构损失和一个KL散度,但是在VQ-VAE中,因为后验概率是一个one hot的概率分布,所以其实这个KL散度是一个常数logk,直接可以忽略掉。

在训练的时候有三个loss,首先是VQ部份有两个loss,都是通过l2范数实现的loss(本质上应用的是字典学习的方法),来将潜在的embedding space和encoder的输出进行一个相互拉近的学习。最后一个loss是对于VAE而言的,一个重建损失。这个重建损失可以根据具体的重建任务去进行设计,本文中则是用了一个似然函数。

具体到,本文的损失函数如下
在这里插入图片描述
其中第一项是重建损失(VAE损失ELBO中的KL散度项,会因为onehot设计变为常数),第二项是字典学习的损失,而第三项是防止字典无限膨胀的commitment loss(让整个学习过程“双向奔赴”)。注意式子中的 s g sg sg是锁定梯度的意思。

但是根据VQ的设计,很容易发现其中用到的argmin操作,是没办法进行梯度回传的,所以在更新梯度的时候需要像红色箭头一样,进行一个梯度连接,让生成器decoder的input梯度,直接回传到编码器的output梯度。这一步的实现可以在代码中一行解决,具体在后面讲述

实验

这里进行一下代码讲解:
这里进行一下VQ的代码讲解:

class VectorQuantizer(nn.Module):
    """
    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings		# embedding 的数量
        self.D = embedding_dim		# embedding 维度
        self.beta = beta			# commitment loss的beta

        self.embedding = nn.Embedding(self.K, self.D)		# 创建dictionary
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)	# 初始化

    def forward(self, latents: Tensor) -> Tensor:
        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.D)  # [BHW x D]

        # Compute L2 distance between latents and embedding weights
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
               torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]

        # Get the encoding that has the min distance
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]

        # Convert to one-hot encodings
        device = latents.device
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

        # Quantize the latents
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Compute the VQ Losses
        # 两个l2损失直接用mse_loss进行实现
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())

        vq_loss = commitment_loss * self.beta + embedding_loss

        # Add the residue back to the latents
        # 这个地方用了一个巧妙的方法实现了梯度的连接回传,在数值上进行了变化但是梯度上进行了保留
        quantized_latents = latents + (quantized_latents - latents).detach()

        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]
    

思考:VQVAE在生成任务上相比于GAN,强在哪里?

很直观的,VQVAE将图像学习到一个嵌入空间中,语义更加明确,可解释性更强,且同类信息可以通过对特征进行一些更改来进行多元化的生成,这是GAN(原教旨主义)没办法解决的。


看完VQ- VAE,再回看之前的BEiT和DALL-E方法


DALL-E: Zero-Shot Text-to-Image Generation

DALL-E其实是VQ- VAE的一个改进版本,此外,将其应用到了多模态领域来进行跨模态生成。DALL- E经过MSCOCO上的预训练,在生成任务上有非常不错的效果,后续也进行了更多次迭代,DALL-E3也是目前数一数二的高分辨率图像生成模型。当然我不是做跨模态生成的,主要学习的是一些可以用在其他领域的思想,所以这里主要介绍一下DALL-E中具体tokenize的方法:

具体方法

首先本文的基本目标是学习一个能将图像和文本的tokens建模成一个单一的数据流的自回归transformer。如果直接用像素来作为image的token,在高分辨率图像中会带来巨大的内存消耗。此外,有一些方法基于似然函数的目标函数进行学习,但这种目标函数更加关注像素间的这种短程关系,仅仅关注了高频细节,但忽略了低频结构

本文的方法也是基于这两点提出的。本文的方法是一个两阶段的方法

  • 阶段一:学习了一个dVAE来将图像从256*256的RGB图像压缩成32*32的image token,利用了VQ的思想,这32*32的token每一个都是8192种取值的其中一个VQ-VAE中的one-hot / dictionary embedding)。这么做可以让整个图像内容的大小缩放到原本的1/192
    下图是dVAE的基本效果展示,作者承认会有一些细节丢失,但是生成图像依旧能够进行整体辨认。 dVAE效果

  • 阶段二:通过连接操作,将256维的BPE编码的文本token和32*32的图像token进行连接,然后一起喂给自回归的transformer中来对联合分布进行训练。(stage 2不是我主要学习的部份)

最终的过程可以被看作一个最大化图像 x x x,描述 y y y以及RGB图像的token z z z的分布联合相似性证据下界(ELB)的过程:
p θ , ψ ( x , y , z ) = p θ ( x ∣ y , z ) p ψ ( y , z ) p_{\theta,\psi}(x,y,z)=p_\theta(x\vert y,z)p_\psi(y,z) pθ,ψ(x,y,z)=pθ(xy,z)pψ(y,z)
这个式子有下界:
ln ⁡ p θ , ψ ( x , y ) ≥ E z ∼ q ϕ ( z ∣ x ) ( ln ⁡ p θ ( x ∣ y , z ) − β D K L ( q ϕ ( y , z ∣ x ) , p ψ ( y , z ) ) ) \ln p_{\theta,\psi}(x,y)\geq\mathop\mathbb E\limits_{z\sim q_\phi(z|x)}(\ln p_\theta(x|y,z)-\\\beta D_{KL}(q_\phi(y,z|x),p_\psi(y,z))) lnpθ,ψ(x,y)zqϕ(zx)E(lnpθ(xy,z)βDKL(qϕ(y,zx),pψ(y,z)))

一下子有好多的公式,这里来进行一下解释。

首先有输入图像为 x x x,作者假设 y y y对于 x x x在给定token z z z时条件独立的
q ϕ q_\phi qϕ是dVAE基于输入 x x x生成的32*32的image token的概率分布
p θ p_\theta pθ表示的是从dVAE的image token中进行生成的RGB图像
p ψ p_\psi pψ表示的是经过transformer生成的图像文本token的联合分布(这个transformer的作用还是没看明白,看看后面是否会有解释)

阶段一:Codebook学习

本文视觉的学习首先采用了VQ-VAE2的结构,VQ-VAE2相比于VQ-VAE,增加了层次结构。但是核心的VQ是一样了,也就是说需要学习一个Codebook作为描绘语义的 dictionary。而DALL-E提出的dVAE结构,将原本的图像变为来8192通道的32*32token。stage1是在视觉模态上的操作,原文真的是一点公式不给,但是在别人的学习记录看到大概是这样一个优化目标(这不就是VQVAE):
在这里插入图片描述
(要组会了来不及了,closeai在DALL-E上面给的东西太少了,这段坑以后再填)


BEIT: BERT Pre-Training of Image Transformers

上一篇是一个基于Pyramid ViT为多模态编码器的模型结构,这一篇则是以语言模型BERT为编码器的结构。

Intro

整个的motivation是基于BERT的,BERT的训练方法是通过MLM的方法来实现的。但是直接用BERT-style的模型架构在视觉信息上进行训练是不行的,原因如下:

  1. 对于每个token(patch),并没有视觉信息的vocabulary库与其进行对应(即不能精确到语义)
  2. 如果采用像素级的恢复任务会导致模型浪费其建模能力在预训练时处理短程依赖关系和高频细节的问题上(即细粒度的预训练往往会消耗更多的时间,因为很难去学像素级别的关系)。

于是本文提出了基于Image Transformer的双向编码器结构,基本的结构如下:
BEiT
在进行预训练前,先学习一个image tokenizer,这和COTS很像,看看这里有没有具体的方法。在预训练的过程中,图像被分为了两个view:一个是tokens(用于重建),一个是patches(用于进行ViT训练)。
简化的流程如下:首先先将一部份的Visual Tokens给掩蔽掉,然后将对应的patches用一个特殊的embedding[M]掩蔽。(也就是说,虽然是两个views,但是都对齐进行了掩蔽)

具体方法

视觉特征提取

如前文所讲,视觉的图像进行了两种处理,第一种是基于DALL- E中利用VQVAE的tokenize方法,将图像划分为token。现在再看这段内容终于明白了,首先是先将图像通过卷积等方法进行下采样,如原本256*256的图像可以经过卷积变成32*32的特征图

通道数就是我们想要的嵌入维度数,如我的embedding space想要512维的,我就进行512通道的卷积。
这里再重新区分一下embedding num和embedding dim,因为很多文章的说法经常会混淆,以512维特征,8192个值的字典为例:

  • 在计算距离做最近邻查找时,依靠的是这512维特征的欧式距离/余弦距离,查找范围是8192个512维的字典向量
  • 查找完后,转化为one-hot时转化的是这8192个值的其中一个的key值
  • 在decoding阶段用的是字典的value

重要的事情说三遍

在BEiT以及COTS中,我们实际使用的是这个KEY值作为token!!!
在BEiT以及COTS中,我们实际使用的是这个KEY值作为token!!!
在BEiT以及COTS中,我们实际使用的是这个KEY值作为token!!!

而下采样的特征图中,每一个像素位置上其实都会有这么一个key值,也就是图中的token。

此外还有最基本的patch方法,本文将原始图像转换为了14*14的patch和14*14的token序列。

骨干网络:Image Transformer

和ViT一样,直接用了标准的Transformer结构作为骨干网络。Transformer的输入直接用image patches,这些patches然后会被一个线性层映射到patch embedding中( x i p → E x i p x_i^p\rarr Ex_i^p xipExip),此外在前面加了一个special token来进行全局学习,在最后加入一个position Embedding。前向过程和transformer基本一致,最后生成的向量表示为 H L = [ h [ s ] L , h 1 L , … , h N L ] H^L=[h_{[s]}^L,h_1^L,\dots,h_N^L] HL=[h[s]L,h1L,,hNL]

MIM masked image modeling

重建的工作首先也是通过一个softmax+分类头将输出的特征进行一个分类操作,然后将分类头和本身被掩蔽的token去做一个负对数似然的损失构建,目标函数可以表示为如下:
MIM objective

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值