How to Sample From Latent Space With Variational Autoencoder


Like a traditional autoencoder, the VAE architecture has two parts: an encoder and a decoder. A traditional AE model maps the input to a single latent-space vector and reconstructs the output from this vector.

A VAE maps the input to a multivariate normal distribution: the encoder outputs the mean and the variance for each latent dimension.

Since the VAE encoder produces a distribution, new data can be generated by sampling from that distribution and passing the sampled latent vector to the decoder. Sampling from the produced distribution to generate output images means that a VAE can generate new data that is similar to, but not identical to, the input data.

This article explores the components of the VAE architecture and presents several ways to generate new images (sampling) with a VAE model. All the code is available on Google Colab.

1. VAE Model Implementation

[Image: An AE model is trained by minimizing the reconstruction loss (e.g., BCE or MSE).]

Both autoencoders and variational autoencoders have two parts: an encoder and a decoder. The encoder neural network of an AE learns to map each image to a single vector in the latent space, and the decoder learns to reconstruct the original image from the encoder's latent vector.

[Image: A VAE model is trained by minimizing the reconstruction loss and the KL divergence.]

The encoder neural network of a VAE outputs parameters that define a probability distribution for each dimension of the latent space (a multivariate distribution). For each input, the encoder produces a mean and a variance for every latent dimension.

The output mean and variance define a multivariate Gaussian distribution. The decoder neural network is the same as in the AE model.

① VAE Loss

The goal of training a VAE model is to maximize the likelihood of generating realistic images from the provided latent vectors. During training, the VAE model minimizes two losses:

  • Reconstruction loss: the difference between the input image and the decoder output.
  • Kullback-Leibler divergence loss (KL divergence is a statistical distance between two probability distributions): the distance between the probability distribution produced by the encoder and a prior distribution (the standard normal distribution); it helps regularize the latent space. For diagonal Gaussians this term has the closed form shown below.
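For an encoder that outputs a diagonal Gaussian N(μ, σ²) and the standard normal prior N(0, I), the KL term has a well-known closed form, summed over the latent dimensions:

KL(N(μ, σ²) || N(0, I)) = ½ Σᵢ (μᵢ² + σᵢ² − ln σᵢ² − 1)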
② Reconstruction Loss

Common reconstruction losses are binary cross-entropy (BCE) and mean squared error (MSE). In this article, I will use the MNIST dataset for the demonstration. MNIST images have only one channel, with pixel values between 0 and 1.

import torch.nn as nn
reconstruction_loss = nn.BCELoss(reduction='sum')
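If the data were not normalized to the [0, 1] range, MSE would be the usual alternative (shown here only as a sketch; the rest of the article uses BCE):

# alternative reconstruction loss for unbounded pixel values
reconstruction_loss = nn.MSELoss(reduction='sum')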
③ Kullback-Leibler Divergence

As mentioned above, the KL divergence measures the difference between two distributions. Note that it does not have the symmetry property of a distance metric: KL(P||Q) ≠ KL(Q||P).

The two distributions to be compared are:

  • the latent-space distribution that the encoder outputs for a given input image x: q(z|x)
  • the latent-space prior p(z), which is assumed to be a normal distribution with mean 0 and standard deviation 1 in each latent dimension: N(0, 1)

This assumption simplifies the KL divergence computation and encourages the latent space to follow a known, manageable distribution.

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

def kl_divergence_loss(z_dist):
    # KL between the encoder's distribution and the standard normal prior,
    # summed over the latent dimensions and over the batch
    return kl_divergence(z_dist,
                         Normal(torch.zeros_like(z_dist.mean),
                                torch.ones_like(z_dist.stddev))
                         ).sum(-1).sum()
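Equivalently, the same loss can be computed directly from the closed form given earlier; a minimal sketch (the helper name kl_divergence_loss_analytic is mine, not from the original notebook):

def kl_divergence_loss_analytic(mean, std):
    # 0.5 * sum(mu^2 + sigma^2 - ln(sigma^2) - 1) over dimensions and batch
    return 0.5 * (mean.pow(2) + std.pow(2) - 2 * std.log() - 1).sum()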
④ Encoder
class Encoder(nn.Module):
    def __init__(self, im_chan=1, output_chan=32, hidden_dim=16):
        super(Encoder, self).__init__()
        self.z_dim = output_chan

        self.encoder = nn.Sequential(
            self.init_conv_block(im_chan, hidden_dim),
            self.init_conv_block(hidden_dim, hidden_dim * 2),
            # double output_chan for mean and std with [output_chan] size
            self.init_conv_block(hidden_dim * 2, output_chan * 2, final_layer=True),
        )

    def init_conv_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=0, final_layer=False):
        layers = [
            nn.Conv2d(input_channels, output_channels,
                          kernel_size=kernel_size,
                          padding=padding,
                          stride=stride)
        ]
        if not final_layer:
            layers += [
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            ]
        return nn.Sequential(*layers)

    def forward(self, image):
        encoder_pred = self.encoder(image)
        encoding = encoder_pred.view(len(encoder_pred), -1)
        mean = encoding[:, :self.z_dim]
        logvar = encoding[:, self.z_dim:]
        # encoding output representing standard deviation is interpreted as
        # the logarithm of the variance associated with the normal distribution
        # take the exponent to convert it to standard deviation
        return mean, torch.exp(logvar*0.5)
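A quick sanity check of the encoder output shapes (assuming MNIST-sized 28×28 inputs):

enc = Encoder(im_chan=1, output_chan=32)
mean, std = enc(torch.randn(4, 1, 28, 28))
print(mean.shape, std.shape)  # torch.Size([4, 32]) torch.Size([4, 32])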
⑤ Decoder
class Decoder(nn.Module):
    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super(Decoder, self).__init__()
        self.z_dim = z_dim
        self.decoder = nn.Sequential(
            self.init_conv_block(z_dim, hidden_dim * 4),
            self.init_conv_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.init_conv_block(hidden_dim * 2, hidden_dim),
            self.init_conv_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def init_conv_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False):
        layers = [
            nn.ConvTranspose2d(input_channels, output_channels,
                               kernel_size=kernel_size,
                               stride=stride, padding=padding)
        ]
        if not final_layer:
            layers += [
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            ]
        else:
            layers += [nn.Sigmoid()]
        return nn.Sequential(*layers)

    def forward(self, z):
        # Ensure the input latent vector z is correctly reshaped for the decoder
        x = z.view(-1, self.z_dim, 1, 1)
        # Pass the reshaped input through the decoder network
        return self.decoder(x)
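The transposed convolutions upsample the 1×1 latent input back to a 28×28 image; a quick check:

dec = Decoder(z_dim=32, im_chan=1)
images = dec(torch.randn(4, 32))
print(images.shape)  # torch.Size([4, 1, 28, 28])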
⑥ VAE Model

To backpropagate through a random sample, you need to move the parameters of the random sample, μ and σ, outside of the sampling function so that gradients can be computed with respect to them. This step is also known as the "reparameterization trick".

In PyTorch, you can create a Normal distribution from the encoder outputs μ and σ and sample from it with the rsample() method, which implements the reparameterization trick: it is equivalent to mean + stddev * torch.randn(z_dim).
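A minimal sketch of what rsample() does under the hood (mean and std here stand for the encoder outputs):

# the randomness comes from eps, so gradients flow through mean and std
eps = torch.randn_like(std)
z = mean + eps * std  # equivalent to Normal(mean, std).rsample()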

class VAE(nn.Module):
  def __init__(self, z_dim=32, im_chan=1):
    super(VAE, self).__init__()
    self.z_dim = z_dim
    self.encoder = Encoder(im_chan, z_dim)
    self.decoder = Decoder(z_dim, im_chan)

  def forward(self, images):
    # the encoder returns (mean, std), which parameterize a Normal distribution
    z_dist = Normal(*self.encoder(images))
    # sample from the distribution with the reparameterization trick
    z = z_dist.rsample()
    decoding = self.decoder(z)
    return decoding, z_dist
⑦ Training the VAE

Load the MNIST training and test data:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])

# Download and load the MNIST training data
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the MNIST test data
testset = datasets.MNIST('.', download=True, train=False, transform=transform)
test_loader = DataLoader(testset, batch_size=64, shuffle=True)
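The training loop below also uses the plotting helpers show_images_grid and show_image from the accompanying notebook; a minimal sketch of what they might look like:

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def show_images_grid(images, title=''):
    # arrange a batch of images into a single grid and plot it
    grid = make_grid(images.detach().cpu(), nrow=8)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.title(title)
    plt.axis('off')
    plt.show()

def show_image(image, title=''):
    # plot a single image tensor
    plt.imshow(image.detach().cpu().squeeze().numpy(), cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()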

[Image: the VAE training steps.]

Create a training loop that follows the VAE training steps shown in the figure above:

from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def train_model(epochs=10, z_dim=16):
  model = VAE(z_dim=z_dim).to(device)
  model_opt = torch.optim.Adam(model.parameters())
  for epoch in range(epochs):
      print(f"Epoch {epoch}")
      # each batch yields (images, labels); the labels are not needed here
      for images, _ in tqdm(train_loader):
          images = images.to(device)
          model_opt.zero_grad()
          recon_images, encoding = model(images)
          loss = reconstruction_loss(recon_images, images) + kl_divergence_loss(encoding)
          loss.backward()
          model_opt.step()
      show_images_grid(images.cpu(), title=f'Input images')
      show_images_grid(recon_images.cpu(), title=f'Reconstructed images')
  return model
z_dim = 8
vae = train_model(epochs=20, z_dim=z_dim)


⑧ Visualizing the Latent Space
from sklearn.manifold import TSNE
import umap
import plotly.express as px

def visualize_latent_space(model, data_loader, device, method='TSNE', num_samples=10000):
    model.eval()
    latents = []
    labels = []
    with torch.no_grad():
        for i, (data, label) in enumerate(data_loader):
          # stop once roughly num_samples encodings have been collected
          if len(latents) * data_loader.batch_size > num_samples:
            break
          mu, _ = model.encoder(data.to(device))
          latents.append(mu.cpu())
          labels.append(label.cpu())

    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()
    assert method in ['TSNE', 'UMAP'], 'method should be TSNE or UMAP'
    if method == 'TSNE':
        tsne = TSNE(n_components=2, verbose=1)
        tsne_results = tsne.fit_transform(latents)
        fig = px.scatter(tsne_results, x=0, y=1, color=labels, labels={'color': 'label'})
        fig.update_layout(title='VAE Latent Space with TSNE',
                          width=600,
                          height=600)
    elif method == 'UMAP':
        reducer = umap.UMAP()
        embedding = reducer.fit_transform(latents)
        fig = px.scatter(embedding, x=0, y=1, color=labels, labels={'color': 'label'})

        fig.update_layout(title='VAE Latent Space with UMAP',
                          width=600,
                          height=600
                          )

    fig.show()
visualize_latent_space(vae, train_loader,
                       device='cuda' if torch.cuda.is_available() else 'cpu',
                       method='UMAP', num_samples=10000)
[Image: the VAE latent space visualized with UMAP.]

2. Sampling With a VAE

Sampling from a variational autoencoder (VAE) makes it possible to generate new data similar to what was seen during training, a distinctive aspect that sets the VAE apart from the traditional AE architecture.

There are several ways to sample from a VAE:

  • Posterior sampling: sampling from the posterior distribution given an input.
  • Prior sampling: sampling from the latent space assuming a standard normal multivariate distribution. This is possible because of the assumption (used during VAE training) that the latent variables are normally distributed. This method does not allow generating data with specific properties (for example, data from a specific class).
  • Interpolation: interpolating between two points in the latent space can reveal how changes in the latent variables correspond to changes in the generated data.
  • Traversal of latent dimensions: the variance of the generated data along the VAE latent space depends on each dimension. Traversal is done by fixing all dimensions of a latent vector except one chosen dimension, and varying the chosen dimension's value across its range. Some dimensions of the latent space may correspond to specific attributes of the data (the VAE has no specific mechanism that enforces this behavior, but it can happen). For example, one latent dimension could control the emotional expression of a face or the orientation of an object.

Each sampling method offers a different way to explore and understand the data properties captured by the VAE latent space.

① Posterior Sampling (From a Given Input Image)


The encoder outputs a distribution over the latent space (a normal distribution with parameters μ_x and σ_x). Sampling from the normal distribution N(μ_x, σ_x) and passing the sampled vector to the decoder generates an image similar to the given input image.

def posterior_sampling(model, data_loader, n_samples=25):
  model.eval()
  images, _ = next(iter(data_loader))
  images = images[:n_samples]
  with torch.no_grad():
    _, encoding_dist = model(images.to(device))
    input_sample=encoding_dist.sample()
    recon_images = model.decoder(input_sample)
    show_images_grid(images, title=f'input samples')
    show_images_grid(recon_images, title=f'generated posterior samples')
posterior_sampling(vae, train_loader, n_samples=30)

Posterior sampling allows generating realistic data samples, but with low variability: the output data is similar to the input data.


② Prior Sampling (From a Random Latent Space Vector)


Sampling from the prior distribution and passing the sampled vector to the decoder makes it possible to generate new data.

def prior_sampling(model, z_dim=32, n_samples = 25):
  model.eval()
  input_sample=torch.randn(n_samples, z_dim).to(device)
  with torch.no_grad():
    sampled_images = model.decoder(input_sample)
  show_images_grid(sampled_images, title=f'generated prior samples')
prior_sampling(vae, z_dim, n_samples=40)

Prior sampling from N(0, 1) does not always generate plausible data, but it has high variability.


③ Sampling From Class Centers


The mean encoding of each class can be accumulated over the whole dataset and later used for controlled (conditional) generation.

def get_data_predictions(model, data_loader):
    model.eval()
    latents_mean = []
    latents_std = []
    labels = []
    with torch.no_grad():
        for i, (data, label) in enumerate(data_loader):
          mu, std = model.encoder(data.to(device))
          latents_mean.append(mu.cpu())
          latents_std.append(std.cpu())
          labels.append(label.cpu())
    latents_mean = torch.cat(latents_mean, dim=0)
    latents_std = torch.cat(latents_std, dim=0)
    labels = torch.cat(labels, dim=0)
    return latents_mean, latents_std, labels
def get_classes_mean(class_to_idx, labels, latents_mean, latents_std):
  classes_mean = {}
  for class_name in class_to_idx:
    class_id = class_to_idx[class_name]
    # average the latent means and stds over all samples of the class
    latents_mean_class = latents_mean[labels==class_id].mean(dim=0, keepdim=True)
    latents_std_class = latents_std[labels==class_id].mean(dim=0, keepdim=True)

    classes_mean[class_id] = [latents_mean_class, latents_std_class]
  return classes_mean
latents_mean, latents_stdvar, labels = get_data_predictions(vae, train_loader)
classes_mean = get_classes_mean(train_loader.dataset.class_to_idx, labels, latents_mean, latents_stdvar)
n_samples = 20
for class_id in classes_mean.keys():
  latents_mean_class, latents_stddev_class = classes_mean[class_id]
  # create normal distribution of the current class
  class_dist = Normal(latents_mean_class, latents_stddev_class)
  percentiles = torch.linspace(0.05, 0.95, n_samples)
  # get samples from different parts of the distribution using icdf
  # https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.icdf 
  class_z_sample = class_dist.icdf(percentiles[:, None].repeat(1, z_dim))
  with torch.no_grad():
    # generate image directly from mean
    class_image_prototype = vae.decoder(latents_mean_class.to(device))
    # generate images sampled from Normal(class mean, class std) 
    class_images = vae.decoder(class_z_sample.to(device))
  show_image(class_image_prototype[0].cpu(), title=f'Class {class_id} prototype image')
  show_images_grid(class_images.cpu(), title=f'Class {class_id} images')

Sampling from a normal distribution with the class mean μ guarantees generating new data from the same class.

[Images: images generated from the class 3 center.]

[Images: images generated from the class 4 center.]

The low and high percentile values used in icdf lead to high variance in the generated data.

④ Interpolation


def linear_interpolation(start, end, steps):
    # Create a linear path from start to end
    z = torch.linspace(0, 1, steps)[:, None].to(device) * (end - start) + start
    # Decode the samples along the path
    vae.eval()
    with torch.no_grad():
      samples = vae.decoder(z)
    return samples

Interpolation between two random latent vectors:

start = torch.randn(1, z_dim).to(device)
end = torch.randn(1, z_dim).to(device)

interpolated_samples = linear_interpolation(start, end, steps = 24)
show_images_grid(interpolated_samples, title=f'Linear interpolation between two random latent vectors')


Interpolation between two class centers:

for start_class_id in range(1,10):
  start = classes_mean[start_class_id][0].to(device)
  for end_class_id in range(1, 10):
    if end_class_id == start_class_id:
      continue
    end = classes_mean[end_class_id][0].to(device)
    interpolated_samples = linear_interpolation(start, end, steps = 20)
    show_images_grid(interpolated_samples, title=f'Linear interpolation between classes {start_class_id} and {end_class_id}')


⑤ Latent Space Traversal

Each dimension of the latent vector represents a normal distribution; the range of values of a dimension is controlled by the dimension's mean and variance. A simple way to traverse the range of values is to use the inverse cumulative distribution function (ICDF) of the normal distribution.

The ICDF takes a value between 0 and 1 (representing a probability) and returns a value from the distribution. For a given probability p, the ICDF outputs a value p_icdf such that the probability of the random variable being less than or equal to p_icdf equals p.

For a normal distribution, icdf(0.5) returns the mean of the distribution, and icdf(0.95) returns a value that is greater than 95% of the data in the distribution.
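A quick check with torch.distributions:

from torch.distributions import Normal

d = Normal(torch.tensor(0.0), torch.tensor(1.0))
print(d.icdf(torch.tensor(0.5)))   # tensor(0.), the mean
print(d.icdf(torch.tensor(0.95)))  # tensor(1.6449), greater than 95% of samples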


Single-dimension latent space traversal:


def latent_space_traversal(model, input_sample, norm_dist, dim_to_traverse, n_samples, latent_dim, device):
    # Create a range of values to traverse
    assert input_sample.shape[0] == 1, 'input sample shape should be [1, latent_dim]'
    # Generate linearly spaced percentiles between 0.1 and 0.9
    percentiles = torch.linspace(0.1, 0.9, n_samples)
    # Get the quantile values corresponding to the percentiles
    traversed_values = norm_dist.icdf(percentiles[:, None].repeat(1, latent_dim))
    # Repeat the input sample to create a batch of latent vectors
    z = input_sample.repeat(n_samples, 1)
    # Assign the traversed values to the specified dimension
    z[:, dim_to_traverse] = traversed_values[:, dim_to_traverse]

    # Decode the latent vectors
    with torch.no_grad():
        samples = model.decoder(z.to(device))
    return samples
for class_id in range(0,10):
  mu, std = classes_mean[class_id]
  with torch.no_grad():
    recon_images = vae.decoder(mu.to(device))
  show_image(recon_images[0], title=f'class {class_id} mean sample')
  for i in range(z_dim):
    interpolated_samples = latent_space_traversal(vae, mu, norm_dist=Normal(mu, torch.ones_like(mu)), dim_to_traverse=i, n_samples=20, latent_dim=z_dim, device=device)
    show_images_grid(interpolated_samples, title=f'Class {class_id} dim={i} traversal')

Traversing a single dimension may lead to a change in digit style or to a change in the digit's orientation.


Two-dimension latent space traversal:


def traverse_two_latent_dimensions(model, input_sample, z_dist, n_samples=25, z_dim=16, dim_1=0, dim_2=1, title='plot'):
  digit_size=28

  percentiles = torch.linspace(0.10, 0.9, n_samples)

  grid_x = z_dist.icdf(percentiles[:, None].repeat(1, z_dim))
  grid_y = z_dist.icdf(percentiles[:, None].repeat(1, z_dim))

  figure = np.zeros((digit_size * n_samples, digit_size * n_samples))

  z_sample_def = input_sample.clone().detach()
  # select two dimensions to vary (dim_1 and dim_2) and keep the rest fixed
  for yi in range(n_samples):
      for xi in range(n_samples):
          with torch.no_grad():
              z_sample = z_sample_def.clone().detach()
              z_sample[:, dim_1] = grid_x[xi, dim_1]
              z_sample[:, dim_2] = grid_y[yi, dim_2]
              x_decoded = model.decoder(z_sample.to(device)).cpu()
          digit = x_decoded[0].reshape(digit_size, digit_size)
          figure[yi * digit_size: (yi + 1) * digit_size,
                 xi * digit_size: (xi + 1) * digit_size] = digit.numpy()

  plt.figure(figsize=(6, 6))
  plt.imshow(figure, cmap='Greys_r')
  plt.title(title)
  plt.show()
for class_id in range(10):
  mu, std = classes_mean[class_id]
  with torch.no_grad():
    recon_images = vae.decoder(mu.to(device))
  show_image(recon_images[0], title=f'class {class_id} mean sample')
  traverse_two_latent_dimensions(vae, mu, z_dist=Normal(mu, torch.ones_like(mu)), n_samples=8, z_dim=z_dim, dim_1=3, dim_2=6, title=f'Class {class_id} traversing dimensions {(3, 6)}')

Traversing multiple dimensions at once provides a controllable way to generate data with high variability.


Digits from a 2D manifold of the latent space:

If a VAE model is trained with z_dim=2, a 2D manifold of digits can be displayed from its latent space. To do this, I will use the traverse_two_latent_dimensions function with dim_1=0 and dim_2=1.

vae_2d = train_model(epochs=10, z_dim=2)
z_dist = Normal(torch.zeros(1, 2), torch.ones(1, 2))
input_sample = torch.zeros(1, 2)
with torch.no_grad():
  decoding = vae_2d.decoder(input_sample.to(device))

traverse_two_latent_dimensions(vae_2d, input_sample, z_dist, n_samples=20, dim_1=0, dim_2=1, z_dim=2, title=f'traversing 2D latent space')

