论文Neural Discrete Representation Learning(VQ-VAE)详解(PyTorch)

https://arxiv.org/pdf/1711.00937v2.pdf(论文下载链接)

github代码下载链接1

github代码下载链接2

        之所以将VQ-VAE(Vector Quantised Variational AutoEncoder)论文,主要是为讲解后面两篇论文做准备,VQ-VAE不管是视频还是博客,都有人在讲解,但是这里也做一个总结,以衔接后面的两篇论文讲解,关于VAE(Variational AutoEncoder)相关的论文比较多,并且其中涉及的数学原理以及推导也比较多,导致我们在阅读VAE方法的时候可能存在较多的困惑,自己在看的过程中也遇到了较多的困惑,但是还是准备做一个总结。

目录

一.提出目的和方法

1.提出目的

2.提出方法

二.VQ-VAE贡献点

三.VQ-VAE具体方法

1.离散化隐藏向量

四.VQ-VAE学习方式

 核心代码实现 

五.实验比较


一.提出目的和方法

1.提出目的

传统的VAE(变分自编码器)在隐空间中使用连续分布,导致生成的隐变量难以进行有效的离散化表示(如用于序列建模或强化学习)。

2.提出方法

VQ-VAE提出了一种​​离散隐变量​​的自编码方法,通过​​向量量化(Vector Quantization, VQ)​​ 实现隐空间的离散化,从而提升表征的可解释性和生成质量具体方法:编码器网络输出离散而非连续代码;且先验分布是动态学习而非静态预设。为学习离散潜在表征,融入了向量量化(VQ)的核心思想。采用VQ方法使模型能够规避VAE框架中常见的"后验坍塌"问题(即当潜在变量与强大的自回归解码器结合时被忽略)

二.VQ-VAE贡献点

  1. 提出VQ-VAE模型:该模型结构简单,采用离散潜在变量,既不会出现"后验坍塌"问题,也不存在方差异常;
验证离散模型性能:证明离散潜在变量模型( VQ-VAE )在对数似然指标上可达到与连续型模型相当的水平;
实现跨模态高质量生成:当与强大先验模型结合时,在语音合成、视频生成等多种应用中均能产生连贯且高质量的样本;
突破无监督语音学习:首次通过原始语音无监督学习到语言结构,并实现了说话人音色转换的实际应用。

与本研究最相关的工作当属变分自编码器(VAEs)。VAE包含以下核心组件:
        1)编码器网络:用于参数化离散潜在随机变量z的后验分布q(z|x),其中x为输入数据;
        2)先验分布p(z)
        3)解码器网络:建立输入数据条件分布p(x|z)

传统VAE通常假设后验分布与先验分布均为对角协方差的正态分布,这种设定可利用高斯重参数化技巧。现有扩展方法包括:

        自回归先验与后验模型
        标准化流
        逆自回归后验

本研究提出的VQ-VAE创新性地采用离散潜在变量,其训练方法受向量量化(VQ)启发。该模型中:
        1)后验与先验分布均为类别分布;
        2)从分布中采样的离散值作为嵌入表的索引;
        3)检索到的嵌入向量将作为解码器网络的输入。

三.VQ-VAE具体方法

1.离散化隐藏向量

注:也就是这里计算嵌入空间和编码器输出向量Ze(x)之间的距离,找到最小距离的索引K,然后下面将其转换为one-hot编码格式。 

转换为one-hot编码格式之后。通过为1的位置获得对应的离散向量。 为了更好的理解这个过程,使用下面的图来给大家表示一下:(建议结合代码看)

四.VQ-VAE学习方式

虽然方程2没有明确定义的梯度,但本文采用类似于​​直通估计器(straight-through estimator)​​的方法来近似梯度,即​​直接将解码器输入 zq(x) 的梯度复制到编码器输出 ze(x)​​。 

# TODO gradient copy trick (Add the residue back to the latents)
quantized_latents = latents + (quantized_latents - latents).detach()
计算量化误差(zq − ze),但通过 .detach() 断开梯度回传,确保这部分不会影响编码器的梯度。
  • 前向传播时,等价于直接返回 zq​(因为 ze​+(zq​−ze​)=zq​)。
  • 反向传播时,由于右侧项被 detach(),梯度会直接通过左侧的 latents(即 ze​)回传,​​相当于将 zq​ 的梯度复制给了 ze​​​。

 核心代码实现 

class VectorQuantizer(nn.Module):
    """
    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta
        #TODO 定义的嵌入向量e
        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents: Tensor) -> Tensor:
        #TODO 编码器输出的编码向量
        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.D)  # [BHW x D]

        # TODO 计算隐藏向量和嵌入向量权重之间的L2距离 Compute L2 distance between latents and embedding weights
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
               torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]

        # TODO 获得最小距离对应的索引 Get the encoding that has the min distance
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]

        # TODO 将其索引转换为对应的one-hot编码 Convert to one-hot encodings
        device = latents.device
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

        #TODO 获得离散化隐藏向量空间 Quantize the latents
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # TODO Compute the VQ Losses
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())

        vq_loss = commitment_loss * self.beta + embedding_loss

        # Add the residue back to the latents
        quantized_latents = latents + (quantized_latents - latents).detach()

        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]

五.实验比较

悄悄举手:若觉得文章有用,不妨留下一个小赞?(´▽`ʃƪ)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值