AIGC笔记--VQVAE模型搭建

1--VQVAE模型

        VAE 模型生成的内容质量不高,原因可能在于将图片编码成连续变量(映射为标准分布),然而将图片编码成离散变量可能会更好(因为现实生活中习惯用离散变量来形容事物,例如人的高矮胖瘦等都是离散的;)

        VQVAE模型的三个关键模块:EncoderDecoderCodebook

        Encoder 将输入编码成特征向量,计算特征向量与 Codebook 中 Embedding 向量的相似性(L2距离),取最相似的 Embedding 向量作为特征向量的替代,并输入到 Decoder 中进行重构输入;

        VQVAE的损失函数包括源图片和重构图片的重构损失,以及 Codebook 中量化过程的量化损失 vq_loss;

        VQ-VAE详细介绍参考:轻松理解 VQ-VAE

2--简单代码实例

import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # 论文中损失函数的第三项
        q_latent_loss = F.mse_loss(quantized, inputs.detach()) # 论文中损失函数的第二项
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach() # 梯度复制
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape # B(256) H(8) W(8) C(64)
        
        # Flatten input BHWC -> BHW, C
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances 计算与embedding space中所有embedding的距离
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) # 取最相似的embedding
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1) # 映射为 one-hot vector
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) # 根据index使用embedding space对应的embedding
        
        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)
            
            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n) 
            
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) 
            
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) # 论文中公式(8)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs) # 计算encoder输出(即inputs)和decoder输入(即quantized)之间的损失
        loss = self._commitment_cost * e_latent_loss
        
        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach() # trick, 将decoder的输入对应的梯度复制,作为encoder的输出对应的梯度
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels = in_channels,
                      out_channels = num_residual_hiddens,
                      kernel_size = 3, stride = 1, padding = 1, bias = False),
            nn.ReLU(True),
            nn.Conv2d(in_channels = num_residual_hiddens,
                      out_channels = num_hiddens,
                      kernel_size = 1, stride = 1, bias = False)
        )
    
    def forward(self, x):
        return x + self._block(x)

class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                             for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels = in_channels,
                                 out_channels = num_hiddens//2,
                                 kernel_size = 4,
                                 stride = 2, padding = 1)
        self._conv_2 = nn.Conv2d(in_channels = num_hiddens//2,
                                 out_channels = num_hiddens,
                                 kernel_size = 4,
                                 stride = 2, padding = 1)
        self._conv_3 = nn.Conv2d(in_channels = num_hiddens,
                                 out_channels = num_hiddens,
                                 kernel_size = 3,
                                 stride = 1, padding = 1)
        self._residual_stack = ResidualStack(in_channels = num_hiddens,
                                             num_hiddens = num_hiddens,
                                             num_residual_layers = num_residual_layers,
                                             num_residual_hiddens = num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        x = self._conv_2(x)
        x = F.relu(x)
        x = self._conv_3(x)
        return self._residual_stack(x)

class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3, 
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, 
                                                out_channels=num_hiddens//2,
                                                kernel_size=4, 
                                                stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2, 
                                                out_channels=3,
                                                kernel_size=4, 
                                                stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        
        x = self._residual_stack(x)
        
        x = self._conv_trans_1(x)
        x = F.relu(x)
        
        return self._conv_trans_2(x)

class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens, 
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()
        
        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers, 
                                num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels = num_hiddens, 
                                      out_channels = embedding_dim,
                                      kernel_size = 1, 
                                      stride = 1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, 
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens, 
                                num_residual_layers, 
                                num_residual_hiddens)

    def forward(self, x): 
        # x.shape: B(256) C(3) H(32) W(32)
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized) # decoder解码还原图像 B(256) C(3) H(32) W(32)

        return loss, x_recon, perplexity

完整代码参考:liujf69/VQ-VAE

3--部分细节解读:

重构损失计算:

        计算源图像和重构图像的MSE损失

vq_loss, data_recon, perplexity = self.model(data)
recon_error = F.mse_loss(data_recon, data) / self.data_variance 

VQ量化损失计算:

        inputs表示Encoder的输出,quantized是Codebook中与 inputs 最接近的向量;

# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # 论文中损失函数的第三项
q_latent_loss = F.mse_loss(quantized, inputs.detach()) # 论文中损失函数的第二项
loss = q_latent_loss + self._commitment_cost * e_latent_loss

Decoder的梯度复制到Encoder中:inputs是Encoder的输出,quantized是Decoder的输入;

quantized = inputs + (quantized - inputs).detach() # 梯度复制

  • 9
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 华为HCIP精品课程笔记-Wakin是一本非常有价值的学习资料。该笔记通过详细介绍HCIP认证的相关知识点和实用技能,为学生提供了一种高效学习的方式。 首先,华为HCIP精品课程笔记-Wakin涵盖了HCIP认证考试需要掌握的全部知识点。对于想要通过HCIP考试的学生来说,这本笔记提供了一个简明扼要的指南,可以帮助他们系统地学习和复习相关知识。笔记中的内容分为多个章节,涵盖了网络技术、路由器和交换机、路由控制、IP多播、IPv6和MPLS等重要主题。每个主题都有详细的解释和示例,帮助学生更好地理解和掌握。 其次,华为HCIP精品课程笔记-Wakin还包含了一些实用技能和案例分析。这些内容能够帮助学生更好地理解和应用所学知识。对于在实际工作中需要应用HCIP技能的人来说,这本笔记提供了一些宝贵的经验和建议。 此外,华为HCIP精品课程笔记-Wakin还提供了一些习题和练习题,可以帮助学生检验自己的学习成果。通过对这些习题的练习,学生可以更好地了解自己的薄弱点,并进行有针对性的复习和提高。 总之,华为HCIP精品课程笔记-Wakin是一本非常实用和有价值的学习资料。它提供了一种高效的学习方式,帮助学生系统地掌握和应用HCIP认证的相关知识和技能。我强烈推荐这本笔记给所有想要通过HCIP考试或者在实际工作中应用HCIP技能的人。 ### 回答2: 《华为HCIP精品课程笔记-wakin》是一本非常有价值的学习资料。这本书由华为公司精心编撰而成,旨在帮助学员高效学习和掌握华为认证网络工程师(HCIP)认证所需的知识和技能。 这本笔记深入浅出地介绍了HCIP认证相关的重要概念、原理和应用。其中包括了网络架构设计、路由与交换技术、安全技术、无线网络技术等内容。每个主题都有详细的解释、示意图和实例,使读者能够更好地理解和应用知识。 除了内容丰富全面外,这本笔记还具有一些独特的优点。首先,它采用了华为独有的学习方法,系统化地梳理了知识结构,使读者能够更加有条理地学习。其次,每个章节都附带了重点整理的要点,方便读者快速回顾和温习。此外,为了帮助读者更好地理解,笔记还提供了一些实验和实际案例,使学习更加实践性和深入。 通过学习《华为HCIP精品课程笔记-wakin》,读者将能够全面了解和掌握HCIP认证所需的知识和技能。这些知识和技能不仅适用于工作中的网络工程师,也对于其他相关岗位的人员有很大的帮助。无论是对于初学者还是对于有一定经验的人来说,这本书都是一本非常实用的学习资料。强烈推荐给所有对网络工程感兴趣的人士。 ### 回答3: 华为HCIP精品课程笔记-Wakin是一份非常有价值的学习资料。这份笔记由华为公司的专业培训师Wakin编写,对于想要学习和提升HCIP认证的人来说,是一份非常实用的参考资料。 Wakin在这份笔记中,详细地介绍了HCIP的知识点和考试重点。他从网络基础、路由交换、安全技术、无线网络等多个方面入手,深入浅出地解释了每个知识点的概念和原理。在每个章节中,Wakin都给出了一些实际案例和实验,帮助我们更好地理解和应用所学内容。 此外,Wakin在笔记中还提供了一些学习方法和技巧。他建议我们在学习过程中,要注重实践,通过自己动手实验和配置设备来加深对知识的理解。他还推荐了一些学习资源和参考书籍,帮助我们更好地补充和扩展所学知识。 总的来说,华为HCIP精品课程笔记-Wakin非常全面且易于理解。无论是准备HCIP认证考试的人,还是想要进一步提升自己网络技术的人,都可以从中受益匪浅。我相信,只要认真学习并灵活运用这份笔记中的知识,就能够在网络领域中取得更好的成绩和发展。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值