VQ矢量量化Python代码学习记录

代码来源于github开源库:https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py#L122

contiguous() 方法用于确保张量在内存中是连续存储的,当张量进行形状变换之后,有可能会导致存储顺序不再是连续的,这时就需要使用 .contiguous() 方法来重新安排存储顺序,以确保后续操作能够正确进行。

.embedding.weight.data.uniform_(a,b)用于把权重值限定在a和b之间

接下来定义VQ类,需要传入num_embedding,即码本中码字的数量;embedding_dim,即码字的维度。通过L2距离获得传入隐变量Latents和码本中码字的距离,以最近邻方法得到量化结果。 

class VectorQuantizer(nn.Module):
    """
    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings   # 10
        self.D = embedding_dim  # 1024
        self.beta = beta

        self.embedding = nn.Embedding(self.K, self.D) # 10*1024
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K) 

    def forward(self, latents: Tensor) -> Tensor:
        latents = latents.permute(0, 2, 3, 1).contiguous() 
         # [B x D x H x W] -> [B x H x W x D] [32*1024*3*3] ->  [32*3*3*1024]
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.D) 
         # [BHW x D] [288*1024]



        # Compute L2 distance between latents and embedding weights
        # 计算L2距离以得到码本中最近的码
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
               torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t()) 
               # [BHW x K] [288 x 10]

        # Get the encoding that has the min distance
        # 得到L2范式中最小的索引,即码在码本中的位置
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [288,1]

        # Convert to one-hot encodings
        device = latents.device

        # 得到和索引相同shape的全零矩阵
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device) # [288 x 10]

        # 将全零矩阵中欧氏距离最小处的索引赋值为1
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [288 x 10]

        # Quantize the latents 相乘,乘0为0,乘1保持不变,取出码本中的码,即量化后的向量
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Compute the VQ Losses
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())

        vq_loss = commitment_loss * self.beta + embedding_loss

        # 保证量化函数和潜变量的梯度相同
        quantized_latents = latents + (quantized_latents - latents).detach()

        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]

  • 20
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值