变分自编码器VAE实现MNIST数据集生成by Pytorch

参考:
原文:Auto-Encoding Variational Bayes

Recap自编码器:

自编码器中,需要输入一个原始图片,原始图片经过编码之后得到一个隐向量,隐向量解码产生原图片对应的图片。在这种情况下,只能生成原图片对应的图片而无法任意生成新的图片,因为隐向量都是原始图片确定的。

变分自编码器VAE

引入变分自编码器(Variational autoencoder)可以在遵循某一分布下随机产生一些隐向量来生成与原始图片不相同的图片,而不需要预先给定原始图片。为了达到这个目的,需要在编码过程增加限制,使得生成的隐向量能够粗略地遵循标准正态分布。
实际情况下,需要在模型的准确率与隐向量服从标准正态分布之间做一个权衡。模型的准确率就是指解码器生成的图片与原图片的相似程度;隐向量分布采用KL散度来衡量与标准正态分布之间的误差。两部分误差之和作为总体的误差来优化。

这里VAE使用了重参数化这个技巧来KL散度的计算。编码器不再是生成一个隐向量,而是生成正态分布的均值和标准差(若是多维正态分布,会有多个均值和标准差),然后根据这两个统计量下的分布抽样生成隐含向量。因为我们想要使得隐含向量服从标准正态分布,即均值为0,标准差为1,通过优化KL散度来使得分布逼近标准正态分布。
在这里插入图片描述
同理,解码器阶段,根据给定的隐变量 z z z来生成多元正态分布的均值 μ x 1 , μ x 2 \mu_{x1},\mu_{x2} μx1,μx2标准差 σ z 1 , σ z 2 \sigma_{z1},\sigma_{z2} σz1,σz2,根据该分布抽样生成数值 x 1 , x 2 x_{1},x_{2} x1,x2
在这里插入图片描述
将编码器和解码器综合在一起:
在这里插入图片描述
设编码器的概率分布为 q ϕ ( z ∣ x ) q_{\phi }(z|x) qϕ(zx),解码器的概率分布为 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(xz)

误差推导:
在这里插入图片描述在这里插入图片描述
这里需要用到正态分布之间的KL散度,直接给出公式,推导见参考文献:
单元正态分布:
K L ( μ 1 , μ 2 , σ 1 , σ 2 ) = log ⁡ σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 − 1 2 KL(\mu_{1},\mu_{2},\sigma_{1},\sigma_{2})=\log{\frac{\sigma_{2}}{\sigma_{1}}}+\frac{\sigma_{1}^{2}+(\mu_{1}-\mu_{2})^{2}}{2\sigma^{2}}-\frac{1}{2} KL(μ1,μ2,σ1,σ2)=logσ1σ2+2σ2σ12+(μ1μ2)221
n n n元正态分布:
K L ( μ 1 , μ 2 , Σ 1 , Σ 2 ) = 1 2 [ log ⁡ det ⁡ Σ 2 det ⁡ Σ 1 − n + t r ( Σ 2 − 1 Σ 1 ) + ( μ 2 − μ 1 ) T Σ 2 − 1 ( μ 2 − μ 1 ) ] KL(\bm{\mu}_{1},\bm{\mu}_{2},\Sigma_{1},\Sigma_{2})=\frac{1}{2}[\log{\frac{\det{\Sigma_{2}}}{\det{\Sigma_{1}}}}-n+tr(\Sigma_{2}^{-1}\Sigma_{1})+(\bm{\mu}_{2}-\bm{\mu}_{1})^{T}\Sigma_{2}^{-1}(\bm{\mu}_{2}-\bm{\mu}_{1})] KL(μ1,μ2,Σ1,Σ2)=21[logdetΣ1detΣ2n+tr(Σ21Σ1)+(μ2μ1)TΣ21(μ2μ1)]
由此可得到:
− D K L ( q ( z ∣ x i ) ∣ ∣ p ( z ) ) = 1 2 ∑ j = 1 J ( 1 + log ⁡ ( ( σ z j ( i ) ) 2 ) − ( μ z j ( i ) ) 2 − ( σ z j ( i ) ) 2 ) -D_{KL}(q(z|x^{i})||p(z))=\frac{1}{2}\sum_{j=1}^{J}(1+\log{((\sigma_{zj}^{(i)}})^{2})-(\mu_{zj}^{(i)})^{2}-(\sigma_{zj}^{(i)})^{2}) DKL(q(zxi)p(z))=21j=1J(1+log((σzj(i))2)(μzj(i))2(σzj(i))2)
通过从分布 q ( z ∣ x ( i ) ) q(z|x^{(i)}) q(zx(i))抽样来近似 E q ( z ∣ x ( i ) ) \mathbb{E}_{q(z|x^{(i)})} Eq(zx(i))。抽样 L L L次,得到 z ( i , l ) , l = 1 , 2 , . . . , L z^{(i,l)},l=1,2,...,L z(i,l),l=1,2,...,L, L L L通常非常小,通常取1
E q ( z ∣ x ( i ) ) ( log ⁡ ( p ( x ( i ) ∣ z ) ) ) = 1 L ∑ l = 1 L log ⁡ p ( x ( i ) ∣ z ( i , l ) ) = 1 L ∑ l = 1 L ∑ j = 1 D 1 2 log ⁡ σ x j 2 + ( x j i − μ x j ) 2 σ x j 2 \mathbb{E}_{q(z|x^{(i)})}(\log{(p(x^{(i)}|z))})=\frac{1}{L}\sum_{l=1}^{L}\log{p(x^{(i)}|z^{(i,l)})}=\frac{1}{L}\sum_{l=1}^{L}\sum_{j=1}^{D}\frac{1}{2}\log{\sigma_{xj}^{2}}+\frac{(x^{i}_{j}-\mu_{xj})}{2\sigma_{xj}^{2}} Eq(zx(i))(log(p(x(i)z)))=L1l=1Llogp(x(i)z(i,l))=L1l=1Lj=1D21logσxj2+2σxj2(xjiμxj)
其中 D D D代表样本 x ( i ) x^{(i)} x(i)的维度,每个数 x j ( i ) x^{(i)}_{j} xj(i)都对应一个正态分布 N ( μ x j , σ x j 2 ) \mathcal{N}(\mu_{xj},\sigma^{2}_{xj}) N(μxj,σxj2)

Pytorch实现MNIST数据集生成

在本实例中,生成器最后的输出不是均值和方差,而是图片向量。所以重构误差看做为生成图片和原始图片的误差。在这里使用的是binary cross entropy,即BCE误差,因为图片中的值都是(0,1)。当然也可以使用平方误差。
B C E = − ∑ i = 1 n ∑ j = 1 d [ y j ( i ) log ⁡ x j ( i ) + ( 1 − y j ( i ) ) log ⁡ ( 1 − x j ( i ) ) ] BCE=-\sum_{i=1}^{n}\sum_{j=1}^{d}[y^{(i)}_{j}\log{x_{j}^{(i)}}+(1-y^{(i)}_{j})\log{(1-x_{j}^{(i)})}] BCE=i=1nj=1d[yj(i)logxj(i)+(1yj(i))log(1xj(i))]

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image


def loss_function(recon_x, x, mu, logvar):
    """
    :param recon_x: generated image
    :param x: original image
    :param mu: latent mean of z
    :param logvar: latent log variance of z
    """
    BCE_loss = nn.BCELoss(reduction='sum')
    reconstruction_loss = BCE_loss(recon_x, x)
    KL_divergence = -0.5 * torch.sum(1+logvar-torch.exp(logvar)-mu**2)
    #KLD_ele = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    #KLD = torch.sum(KLD_ele).mul_(-0.5)
    print(reconstruction_loss, KL_divergence)

    return reconstruction_loss + KL_divergence


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2_mean = nn.Linear(400, 20)
        self.fc2_logvar = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mean(h1), self.fc2_logvar(h1)

    def reparametrization(self, mu, logvar):
        # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
        std = 0.5 * torch.exp(logvar)
        # N(mu, std^2) = N(0, 1) * std + mu
        z = torch.randn(std.size()) * std + mu
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)
        return self.decode(z), mu, logvar


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

vae = VAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)

# Training
def train(epoch):
    vae.train()
    all_loss = 0.
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to('cpu'), targets.to('cpu')
        real_imgs = torch.flatten(inputs, start_dim=1)

        # Train Discriminator
        gen_imgs, mu, logvar = vae(real_imgs)
        loss = loss_function(gen_imgs, real_imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_loss += loss.item()
        print('Epoch {}, loss: {:.6f}'.format(epoch, all_loss/(batch_idx+1)))
        # Save generated images for every epoch
    fake_images = gen_imgs.view(-1, 1, 28, 28)
    save_image(fake_images, 'MNIST_FAKE/fake_images-{}.png'.format(epoch + 1))



for epoch in range(20):
    train(epoch)

torch.save(vae.state_dict(), './vae.pth')

运行上述代码20轮所产生的的图片:
在这里插入图片描述
VAE和自编码器有一样的缺点,根据均平方误差计算的图片会比较模糊,之后出现的对抗生成网络则解决了这个问题。

好的,以下是变分自动编码器(VAE)的 PyTorch 实现代码: ```python import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torchvision from torchvision import transforms class VAE(nn.Module): def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20): super(VAE, self).__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.latent_dim = latent_dim # Encoder layers self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc21 = nn.Linear(hidden_dim, latent_dim) self.fc22 = nn.Linear(hidden_dim, latent_dim) # Decoder layers self.fc3 = nn.Linear(latent_dim, hidden_dim) self.fc4 = nn.Linear(hidden_dim, input_dim) def encode(self, x): h1 = F.relu(self.fc1(x)) mu = self.fc21(h1) logvar = self.fc22(h1) return mu, logvar def reparameterize(self, mu, logvar): std = torch.exp(0.5*logvar) eps = torch.randn_like(std) return mu + eps*std def decode(self, z): h3 = F.relu(self.fc3(z)) return torch.sigmoid(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x.view(-1, self.input_dim)) z = self.reparameterize(mu, logvar) return self.decode(z), mu, logvar # Define loss function def loss_function(recon_x, x, mu, logvar): BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE + KLD # Load MNIST dataset transform = transforms.Compose([transforms.ToTensor()]) train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True) # Initialize model and optimizer model = VAE() optimizer = optim.Adam(model.parameters(), lr=1e-3) # Train model for epoch in range(1, 11): train_loss = 0 for batch_idx, (data, _) in enumerate(train_loader): optimizer.zero_grad() recon_batch, mu, logvar = model(data) loss = loss_function(recon_batch, data, mu, logvar) loss.backward() train_loss += loss.item() optimizer.step() if batch_idx % 100 == 0: print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item() / len(data))) print('====> Epoch: {} Average loss: {:.4f}'.format( epoch, train_loss / len(train_loader.dataset))) ``` 这段代码定义了一个包含一个编码器和一个解码器的 VAE 模型。在训练过程中,模型会从 MNIST 数据集中读取图像数据,并使用 Adam 优化器来更新模型参数。模型在每个 epoch 结束后会输出平均损失。
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值