AIGC笔记--VAE模型的搭建

目录

1--VAE模型

2--代码实例


1--VAE模型

简单介绍:

        通过一个 encoder 将图片映射到标准分布(均值和方差),从映射的标准分布中随机采样一个样本,通过 decoder 重构图片;

        计算源图片和重构图片之间的损失,具体损失函数的推导可以参考:变分自编码器(VAE)

2--代码实例

简单的VAE模型搭建:

        Encoder 返回映射标准分布的均值和方差,从标准分布中随机采样,利用Decoder重构图片;

class VAE(nn.Module):
    def __init__(self, input_dim = 784, h_dim = 400, z_dim = 20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, h_dim) # 28*28 → 784
        self.fc21 = nn.Linear(h_dim, z_dim) # 均值
        self.fc22 = nn.Linear(h_dim, z_dim) # 标准差
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim) # 784 → 28*28
        self.input_dim = input_dim

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1) # 均值、标准差

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        # z = mu + eps*std
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # sigmoid: 0-1 之间,后边会用到 BCE loss 计算重构 loss(reconstruction loss)
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

简单的损失计算:

x_reconst, mu, log_var = self.model(x)
# Compute reconstruction loss and kl divergence
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Backprop and optimize
loss = reconst_loss + kl_div

完整可运行代码参考:VAE代码实例

  • 14
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值