Like a traditional autoencoder, the VAE architecture has two parts: an encoder and a decoder. A traditional AE model maps the input to a latent-space vector and reconstructs the output from that vector.
A VAE maps the input to a multivariate normal distribution (the encoder outputs the mean and variance of each latent dimension).
Because the VAE encoder produces a distribution, new data can be generated by sampling from that distribution and passing the sampled latent vector to the decoder. Sampling from the produced distribution to generate output images means the VAE can generate new data that is similar to, but not identical to, the input data.
This article explores the components of the VAE architecture and presents several ways of generating new images (sampling) with a VAE model. All the code is available on Google Colab.
1. VAE Model Implementation
An AE model is trained by minimizing a reconstruction loss (for example, BCE or MSE).
Autoencoders and variational autoencoders both have two parts: an encoder and a decoder. The AE encoder network learns to map each image to a single vector in the latent space, and the decoder learns to reconstruct the original image from the encoder's latent vector.
A VAE model is trained by minimizing a reconstruction loss and a KL-divergence loss.
The VAE encoder network outputs parameters that define a probability distribution for each dimension of the latent space (a multivariate distribution). For each input, the encoder produces a mean and a variance for every latent dimension.
The output mean and variance define a multivariate Gaussian distribution. The decoder network is the same as in the AE model.
① VAE Loss
The goal of training a VAE model is to maximize the likelihood of generating realistic images from the provided latent vectors. During training, the VAE minimizes two losses:
- reconstruction loss: the difference between the input image and the decoder output.
- Kullback-Leibler divergence loss (KL divergence is a statistical distance between two probability distributions): the distance between the probability distribution produced by the encoder and a prior distribution (the standard normal distribution), which helps regularize the latent space.
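The two losses above are combined into a single training objective. A minimal sketch, assuming BCE as the reconstruction loss and a `Normal` latent distribution `z_dist` (the helper name `vae_loss` is mine, not from the original code):

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

def vae_loss(recon_images, images, z_dist):
    # reconstruction term: how closely the decoder output matches the input
    recon = F.binary_cross_entropy(recon_images, images, reduction='sum')
    # KL term: distance between the encoder distribution and the N(0, 1) prior
    prior = Normal(torch.zeros_like(z_dist.mean), torch.ones_like(z_dist.stddev))
    kl = kl_divergence(z_dist, prior).sum()
    return recon + kl
```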
② Reconstruction Loss
Common reconstruction losses are binary cross-entropy (BCE) and mean squared error (MSE). In this article I use the MNIST dataset for the demonstration. MNIST images have a single channel, with pixel values between 0 and 1.
```python
reconstruction_loss = nn.BCELoss(reduction='sum')
```
③ Kullback-Leibler Divergence
As mentioned above, KL divergence measures the difference between two distributions. Note that it does not have the symmetry property of a distance: KL(P||Q) ≠ KL(Q||P).
The two distributions to compare are:
- the latent-space distribution produced by the encoder for a given input image x: q(z|x)
- the latent prior p(z), assumed to be a normal distribution with mean 0 and standard deviation 1 in every latent dimension: N(0, 1)
This assumption simplifies the KL divergence computation and encourages the latent space to follow a known, manageable distribution.
```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

def kl_divergence_loss(z_dist):
    return kl_divergence(z_dist,
                         Normal(torch.zeros_like(z_dist.mean),
                                torch.ones_like(z_dist.stddev))
                         ).sum(-1).sum()
```
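As a sanity check (my own, not part of the original article): for a diagonal Gaussian q = N(μ, σ) and the standard normal prior, the KL divergence has the closed form ½ Σ (σ² + μ² − 1 − log σ²), which matches what `kl_divergence` computes:

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

mean = torch.randn(3, 4)
std = torch.rand(3, 4) + 0.1
z_dist = Normal(mean, std)

# library computation of KL(q(z|x) || N(0, 1)), summed over dims and batch
prior = Normal(torch.zeros_like(mean), torch.ones_like(std))
kl_lib = kl_divergence(z_dist, prior).sum()

# closed form: 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))
kl_manual = 0.5 * (std**2 + mean**2 - 1 - torch.log(std**2)).sum()

print(torch.allclose(kl_lib, kl_manual, atol=1e-5))  # True
```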
④ Encoder
```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, im_chan=1, output_chan=32, hidden_dim=16):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        self.encoder = nn.Sequential(
            self.init_conv_block(im_chan, hidden_dim),
            self.init_conv_block(hidden_dim, hidden_dim * 2),
            # double output_chan to produce both a mean and a std of size [output_chan]
            self.init_conv_block(hidden_dim * 2, output_chan * 2, final_layer=True),
        )

    def init_conv_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=0, final_layer=False):
        layers = [
            nn.Conv2d(input_channels, output_channels,
                      kernel_size=kernel_size,
                      padding=padding,
                      stride=stride)
        ]
        if not final_layer:
            layers += [
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            ]
        return nn.Sequential(*layers)

    def forward(self, image):
        encoder_pred = self.encoder(image)
        encoding = encoder_pred.view(len(encoder_pred), -1)
        mean = encoding[:, :self.z_dim]
        logvar = encoding[:, self.z_dim:]
        # the second half of the encoder output is interpreted as the logarithm
        # of the variance of the normal distribution; exponentiate half of it
        # to convert it to a standard deviation
        return mean, torch.exp(logvar * 0.5)
```
⑤ Decoder
```python
class Decoder(nn.Module):
    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super(Decoder, self).__init__()
        self.z_dim = z_dim
        self.decoder = nn.Sequential(
            self.init_conv_block(z_dim, hidden_dim * 4),
            self.init_conv_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.init_conv_block(hidden_dim * 2, hidden_dim),
            self.init_conv_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def init_conv_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False):
        layers = [
            nn.ConvTranspose2d(input_channels, output_channels,
                               kernel_size=kernel_size,
                               stride=stride, padding=padding)
        ]
        if not final_layer:
            layers += [
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            ]
        else:
            layers += [nn.Sigmoid()]
        return nn.Sequential(*layers)

    def forward(self, z):
        # Reshape the latent vector z into a [batch, z_dim, 1, 1] feature map
        x = z.view(-1, self.z_dim, 1, 1)
        # Pass the reshaped input through the transposed-convolution stack
        return self.decoder(x)
```
⑥ VAE Model
To backpropagate through a random sample, the parameters of the sample, μ and σ, must be moved outside the sampling function so that gradients can flow through them. This step is known as the "reparameterization trick".
In PyTorch, you can create a Normal distribution from the encoder outputs μ and σ and sample from it with the rsample() method, which implements the reparameterization trick: it is equivalent to torch.randn(z_dim) * stddev + mean.
```python
class VAE(nn.Module):
    def __init__(self, z_dim=32, im_chan=1):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encoder = Encoder(im_chan, z_dim)
        self.decoder = Decoder(z_dim, im_chan)

    def forward(self, images):
        # the encoder returns (mean, std); unpack them into a Normal distribution
        mean, std = self.encoder(images)
        z_dist = Normal(mean, std)
        # sample from the distribution with the reparameterization trick
        z = z_dist.rsample()
        decoding = self.decoder(z)
        return decoding, z_dist
```
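A quick illustration (my own sketch, not part of the original code) of why `rsample()` matters: gradients flow back to the distribution parameters, which plain `sample()` does not allow:

```python
import torch
from torch.distributions import Normal

mean = torch.zeros(1, 4, requires_grad=True)
std = torch.ones(1, 4, requires_grad=True)

# rsample() draws eps ~ N(0, 1) internally and returns mean + eps * std,
# so the sample stays differentiable with respect to mean and std
z = Normal(mean, std).rsample()
z.sum().backward()

print(mean.grad)  # tensor([[1., 1., 1., 1.]]) — gradients flow through the sample
```

With `sample()` instead of `rsample()`, the returned tensor is detached from `mean` and `std`, and `backward()` would have nothing to propagate to the encoder.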
⑦ Training the VAE
Load the MNIST training and test data:
```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])
# Download and load the MNIST training data
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the MNIST test data
testset = datasets.MNIST('.', download=True, train=False, transform=transform)
test_loader = DataLoader(testset, batch_size=64, shuffle=True)
```
Create a training loop that follows the VAE training steps described above:
```python
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def train_model(epochs=10, z_dim=16):
    model = VAE(z_dim=z_dim).to(device)
    model_opt = torch.optim.Adam(model.parameters())
    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        for images, _ in tqdm(train_loader):
            images = images.to(device)
            model_opt.zero_grad()
            recon_images, encoding_dist = model(images)
            loss = reconstruction_loss(recon_images, images) + kl_divergence_loss(encoding_dist)
            loss.backward()
            model_opt.step()
        show_images_grid(images.cpu(), title=f'Input images')
        show_images_grid(recon_images.detach().cpu(), title=f'Reconstructed images')
    return model

z_dim = 8
vae = train_model(epochs=20, z_dim=z_dim)
```
⑧ Visualizing the Latent Space
```python
import umap
import plotly.express as px
from sklearn.manifold import TSNE

def visualize_latent_space(model, data_loader, device, method='TSNE', num_samples=10000):
    model.eval()
    latents = []
    labels = []
    with torch.no_grad():
        for i, (data, label) in enumerate(data_loader):
            # count collected samples (not batches) against the limit
            if sum(map(len, latents)) > num_samples:
                break
            mu, _ = model.encoder(data.to(device))
            latents.append(mu.cpu())
            labels.append(label.cpu())
    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()
    assert method in ['TSNE', 'UMAP'], 'method should be TSNE or UMAP'
    if method == 'TSNE':
        tsne = TSNE(n_components=2, verbose=1)
        tsne_results = tsne.fit_transform(latents)
        fig = px.scatter(tsne_results, x=0, y=1, color=labels, labels={'color': 'label'})
        fig.update_layout(title='VAE Latent Space with TSNE',
                          width=600,
                          height=600)
    elif method == 'UMAP':
        reducer = umap.UMAP()
        embedding = reducer.fit_transform(latents)
        fig = px.scatter(embedding, x=0, y=1, color=labels, labels={'color': 'label'})
        fig.update_layout(title='VAE Latent Space with UMAP',
                          width=600,
                          height=600)
    fig.show()

visualize_latent_space(vae, train_loader,
                       device='cuda' if torch.cuda.is_available() else 'cpu',
                       method='UMAP', num_samples=10000)
```
![img](https://img-blog.csdnimg.cn/img_convert/243e641427e7d0bc6bd03956492c0a91.jpeg)
The VAE latent space visualized with UMAP.
2. Sampling From a VAE
Sampling from a variational autoencoder generates new data similar to what was seen during training, and it is the aspect that distinguishes a VAE from a traditional AE architecture.
There are several ways to sample from a VAE:
- Posterior sampling: sample from the posterior distribution of a given input.
- Prior sampling: sample from the latent space assuming a standard multivariate normal distribution. This is possible because VAE training assumes the latent variables are normally distributed. This method does not allow generating data with specific properties (for example, from a specific class).
- Interpolation: interpolating between two points in the latent space reveals how changes in the latent variables correspond to changes in the generated data.
- Latent dimension traversal: traversing a latent dimension shows how the variance of the generated data depends on that dimension. A traversal fixes all dimensions of a latent vector except one selected dimension and varies the selected dimension across its range of values. Some dimensions of the latent space may correspond to specific attributes of the data (the VAE has no specific mechanism that enforces this behavior, but it can happen). For example, one latent dimension might control the emotional expression of a face or the orientation of an object.
Each sampling method offers a different way to explore and understand the data properties captured by the VAE latent space.
① Posterior Sampling (From a Given Input Image)
The encoder outputs a distribution over the latent space (a normal distribution with parameters μ_x and σ_x). Sampling from N(μ_x, σ_x) and passing the sampled vector to the decoder generates an image similar to the given input image.
```python
def posterior_sampling(model, data_loader, n_samples=25):
    model.eval()
    images, _ = next(iter(data_loader))
    images = images[:n_samples]
    with torch.no_grad():
        _, encoding_dist = model(images.to(device))
        input_sample = encoding_dist.sample()
        recon_images = model.decoder(input_sample)
    show_images_grid(images, title=f'input samples')
    show_images_grid(recon_images, title=f'generated posterior samples')

posterior_sampling(vae, train_loader, n_samples=30)
```
Posterior sampling generates realistic data samples but with low variability: the output data resembles the input data.
② Prior Sampling (From a Random Latent Space Vector)
Sampling from the prior distribution and passing the sampled vector to the decoder generates new data.
```python
def prior_sampling(model, z_dim=32, n_samples=25):
    model.eval()
    input_sample = torch.randn(n_samples, z_dim).to(device)
    with torch.no_grad():
        sampled_images = model.decoder(input_sample)
    show_images_grid(sampled_images, title=f'generated prior samples')

prior_sampling(vae, z_dim, n_samples=40)
```
Prior sampling from N(0, 1) does not always produce plausible data, but it has high variability.
③ Sampling From Class Centers
The mean encoding of each class can be accumulated over the entire dataset and then used for controlled (conditional) generation.
```python
def get_data_predictions(model, data_loader):
    model.eval()
    latents_mean = []
    latents_std = []
    labels = []
    with torch.no_grad():
        for i, (data, label) in enumerate(data_loader):
            mu, std = model.encoder(data.to(device))
            latents_mean.append(mu.cpu())
            latents_std.append(std.cpu())
            labels.append(label.cpu())
    latents_mean = torch.cat(latents_mean, dim=0)
    latents_std = torch.cat(latents_std, dim=0)
    labels = torch.cat(labels, dim=0)
    return latents_mean, latents_std, labels

def get_classes_mean(class_to_idx, labels, latents_mean, latents_std):
    classes_mean = {}
    for class_name in class_to_idx:
        class_id = class_to_idx[class_name]
        latents_mean_class = latents_mean[labels == class_id].mean(dim=0, keepdims=True)
        latents_std_class = latents_std[labels == class_id].mean(dim=0, keepdims=True)
        classes_mean[class_id] = [latents_mean_class, latents_std_class]
    return classes_mean

latents_mean, latents_stdvar, labels = get_data_predictions(vae, train_loader)
classes_mean = get_classes_mean(train_loader.dataset.class_to_idx, labels, latents_mean, latents_stdvar)

n_samples = 20
for class_id in classes_mean.keys():
    latents_mean_class, latents_stddev_class = classes_mean[class_id]
    # create a normal distribution for the current class
    class_dist = Normal(latents_mean_class, latents_stddev_class)
    percentiles = torch.linspace(0.05, 0.95, n_samples)
    # get samples from different parts of the distribution using icdf
    # https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.icdf
    class_z_sample = class_dist.icdf(percentiles[:, None].repeat(1, z_dim))
    with torch.no_grad():
        # generate an image directly from the class mean
        class_image_prototype = vae.decoder(latents_mean_class.to(device))
        # generate images sampled from Normal(class mean, class std)
        class_images = vae.decoder(class_z_sample.to(device))
    show_image(class_image_prototype[0].cpu(), title=f'Class {class_id} prototype image')
    show_images_grid(class_images.cpu(), title=f'Class {class_id} images')
```
Sampling from a normal distribution centered on the class mean μ guarantees that the new data is generated from the same class.
Images generated from the center of class 3.
Images generated from the center of class 4.
Low and high percentile values passed to icdf lead to high data variance.
④ Interpolation
```python
def linear_interpolation(start, end, steps):
    # Create a linear path from start to end
    z = torch.linspace(0, 1, steps)[:, None].to(device) * (end - start) + start
    # Decode the samples along the path
    vae.eval()
    with torch.no_grad():
        samples = vae.decoder(z)
    return samples
```
Interpolation between two random latent vectors:
```python
start = torch.randn(1, z_dim).to(device)
end = torch.randn(1, z_dim).to(device)
interpolated_samples = linear_interpolation(start, end, steps=24)
show_images_grid(interpolated_samples, title=f'Linear interpolation between two random latent vectors')
```
Interpolation between two class centers:
```python
for start_class_id in range(1, 10):
    start = classes_mean[start_class_id][0].to(device)
    for end_class_id in range(1, 10):
        if end_class_id == start_class_id:
            continue
        end = classes_mean[end_class_id][0].to(device)
        interpolated_samples = linear_interpolation(start, end, steps=20)
        show_images_grid(interpolated_samples, title=f'Linear interpolation between classes {start_class_id} and {end_class_id}')
```
⑤ Latent Space Traversal
Each dimension of the latent vector represents a normal distribution; the range of values for a dimension is controlled by that dimension's mean and variance. A simple way to traverse the range of values is to use the inverse cumulative distribution function (ICDF) of the normal distribution.
The ICDF takes a value between 0 and 1 (representing a probability) and returns a value from the distribution. For a given probability p, the ICDF outputs a value p_icdf such that the probability of the random variable being ≤ p_icdf equals p.
For a normal distribution, icdf(0.5) returns the mean of the distribution, and icdf(0.95) returns a value that is greater than 95% of the data in the distribution.
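To illustrate (a standalone check using PyTorch's `Normal.icdf`):

```python
import torch
from torch.distributions import Normal

dist = Normal(torch.tensor(0.0), torch.tensor(1.0))

# icdf(0.5) is the median, which equals the mean for a normal distribution
print(dist.icdf(torch.tensor(0.5)))   # tensor(0.)

# icdf(0.95) is the value below which 95% of the probability mass lies
print(dist.icdf(torch.tensor(0.95)))  # roughly 1.645
```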
Single-dimension latent space traversal:
```python
def latent_space_traversal(model, input_sample, norm_dist, dim_to_traverse, n_samples, latent_dim, device):
    assert input_sample.shape[0] == 1, 'input sample shape should be [1, latent_dim]'
    # Generate linearly spaced percentiles between 0.1 and 0.9
    percentiles = torch.linspace(0.1, 0.9, n_samples)
    # Get the quantile values corresponding to the percentiles
    traversed_values = norm_dist.icdf(percentiles[:, None].repeat(1, latent_dim))
    # Repeat the input sample to build a batch of latent vectors
    z = input_sample.repeat(n_samples, 1)
    # Assign the traversed values to the selected dimension only
    z[:, dim_to_traverse] = traversed_values[:, dim_to_traverse]
    # Decode the latent vectors
    with torch.no_grad():
        samples = model.decoder(z.to(device))
    return samples
```
```python
for class_id in range(0, 10):
    mu, std = classes_mean[class_id]
    with torch.no_grad():
        recon_images = vae.decoder(mu.to(device))
    show_image(recon_images[0], title=f'class {class_id} mean sample')
    for i in range(z_dim):
        interpolated_samples = latent_space_traversal(vae, mu, norm_dist=Normal(mu, torch.ones_like(mu)), dim_to_traverse=i, n_samples=20, latent_dim=z_dim, device=device)
        show_images_grid(interpolated_samples, title=f'Class {class_id} dim={i} traversal')
```
Traversing a single dimension may change the digit style or control the digit orientation.
Two-dimension latent space traversal:
```python
import numpy as np
import matplotlib.pyplot as plt

def traverse_two_latent_dimensions(model, input_sample, z_dist, n_samples=25, z_dim=16, dim_1=0, dim_2=1, title='plot'):
    digit_size = 28
    percentiles = torch.linspace(0.10, 0.9, n_samples)
    grid_x = z_dist.icdf(percentiles[:, None].repeat(1, z_dim))
    grid_y = z_dist.icdf(percentiles[:, None].repeat(1, z_dim))
    figure = np.zeros((digit_size * n_samples, digit_size * n_samples))
    z_sample_def = input_sample.clone().detach()
    # vary the two selected dimensions (dim_1 and dim_2) and keep the rest fixed
    for yi in range(n_samples):
        for xi in range(n_samples):
            with torch.no_grad():
                z_sample = z_sample_def.clone().detach()
                z_sample[:, dim_1] = grid_x[xi, dim_1]
                z_sample[:, dim_2] = grid_y[yi, dim_2]
                x_decoded = model.decoder(z_sample.to(device)).cpu()
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[yi * digit_size: (yi + 1) * digit_size,
                   xi * digit_size: (xi + 1) * digit_size] = digit.numpy()
    plt.figure(figsize=(6, 6))
    plt.imshow(figure, cmap='Greys_r')
    plt.title(title)
    plt.show()
```
```python
for class_id in range(10):
    mu, std = classes_mean[class_id]
    with torch.no_grad():
        recon_images = vae.decoder(mu.to(device))
    show_image(recon_images[0], title=f'class {class_id} mean sample')
    traverse_two_latent_dimensions(vae, mu, z_dist=Normal(mu, torch.ones_like(mu)), n_samples=8, z_dim=z_dim, dim_1=3, dim_2=6, title=f'Class {class_id} traversing dimensions {(3, 6)}')
```
Traversing several dimensions at once provides a controllable way to generate data with high variability.
Digits from a 2D manifold of the latent space:
If a VAE model is trained with z_dim=2, a 2D manifold of digits can be displayed from its latent space. To do so, I use the traverse_two_latent_dimensions function with dim_1=0 and dim_2=1.
```python
vae_2d = train_model(epochs=10, z_dim=2)
z_dist = Normal(torch.zeros(1, 2), torch.ones(1, 2))
input_sample = torch.zeros(1, 2)
with torch.no_grad():
    decoding = vae_2d.decoder(input_sample.to(device))
traverse_two_latent_dimensions(vae_2d, input_sample, z_dist, n_samples=20, dim_1=0, dim_2=1, z_dim=2, title=f'traversing 2D latent space')
```