VAE介绍
背景
早期的机器学习关注分类和预测问题居多,关注生成式的研究较少。而生成式任务具有重要价值和应用场景,例如当样本(如图像样本)数量较少时,不足以支持后续的任务,此时,生成式任务应运而生。它可以生成新的不存在的样本(典型的如人脸生成网站 ,每次都会生成现实中不存在人脸),为下游任务做铺垫。
当前随着人工智能生成内容(Artificial Intelligence Generated Content,AIGC)的发展,越来越多的生成式模型产生,如大模型中的文本生成文本,文本生成图像,文本生成语音等任务。
变分自编码器(Variational Auto-Encoder,VAE)也是一个典型的生成式模型,目前被广泛应用于复杂的模型,是必须掌握的一种模型架构。
原理介绍
VAE的核心思想是得到样本的嵌入表示(编码器Encoder后),通过常见的分布拟合嵌入表示,常见的分布往往是标准正态分布或正态分布。在拟合后,我们就可以生成正态分布的样本,并将这些样本作为解码器的输入,经过解码后得到的就是生成的所需样本。
VAE是AE(Auto-Endoer)的改进,在讲解VAE前,我们先讲解AE的原理。
AE的结构如下:
这里可以简单将AE分为三个部分:编码器,解码器和嵌入表示(图中的粉色部分)。它含有两个MLP,分别是编码器和解码器,两个MLP中的权重是我们需要学习的。
那么怎么学习呢?
我们输入样本,在经历编码器和解码器两个MLP后会得到输出(Decoder后的结果),我们的目的是要求经过编码和解码后的输出要和输入相同,因此损失函数应该定义为输出和输入之间的差别,即重构误差。我们可以采用均方损失等损失函数来表征输入和输出的重构误差。
AE有什么作用呢?
大家可能会问,经过编码和解码,使得输入和输出尽量接近,这到底有什么用。其实,我们这样做主要是为了得到嵌入表示(图中的粉色部分)。往往嵌入表示的维度要小于输入,我们可以使用嵌入来得到对输入降维的目的。
介绍完了AE,再介绍进阶版VAE的原理。
VAE的框架图:
VAE和AE的区别在于,VAE需要在编码器后使用常见的分布(往往是正态分布)去拟合嵌入表示(图中的Latent Space),分布间的损失可以用KL散度表示,KL散度的数学定义比较复杂,我们只需知道它可以度量两个分布间的差异即可。对于不同的分布,KL散度公式不同,读者可自己推导或查询资料。
KL散度作为损失的一部分,通过反向传播更新梯度,得到的嵌入表示和拟合的分布十分接近,这样我们就可以在编码器和解码器之前使用拟合的分布中采样的随机噪声来生成新的数据点。经过解码器解码后就得到新的数据点。
重构误差作为损失的另一部分,可以表征输入和输出的损失。KL散度和重构误差相加,作为VAE的损失。
核心代码讲解
VAE模型的定义:
class VAE(nn.Module):
def __init__(self, input_size = 28*28, hidden_dim = 64):
super(VAE, self).__init__()
self.encoder = nn.Sequential(nn.Linear(input_size,256),nn.ReLU(),nn.Linear(256, 128), nn.ReLU())
self.mean = nn.Linear(128, hidden_dim)
self.logvar = nn.Linear(128, hidden_dim)
self.decoder = nn.Sequential(nn.Linear(hidden_dim, 128), nn.ReLU(), nn.Linear(128,256), nn.ReLU(), nn.Linear(256, input_size), nn.Tanh())
def reparameters(self, mean, logvar):
std = torch.exp(0.5*logvar)
z = torch.randn(std.size(), device=std.device)*std + mean
return z
def forward(self, x):
x = self.encoder(x)
mean = self.mean(x)
logvar = self.logvar(x)
z = self.reparameters(mean, logvar)
new_image = self.decoder(z)
return new_image, mean, logvar
输入的为单通道(黑白)的图像,图片长和宽为28,因为输入尺寸为28*28=784,潜在向量的维度是超参数,可以自由定义,这里我定义为hidden_dim=64.
这里需要注意的是一般的AE就表示为编码器encoder和解码器decoder,但是因为这里我们把encoder的最后一层拿出来,经历两个不同的线形层,分别得到潜在向量的均值以及方差的对数(对方差取对数是为了保证方差为正)。也就是说,这里我们表示的不是编码器之后的潜在向量,而是分别表示潜在向量的均值和方差的对数,有了均值和方差,我们就可以使用标准正态分布去拟合潜在向量。
另外需要注意的是,我们在VAE中定义了reparameters方法,这是重参数操作。因为采样不是连续的,无法反向求梯度。而通过重参数,我们让潜在向量的生成通过模型参数 μ(均值)和 σ(标准差)来控制。
完整代码
模型生成:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.optim import Adam
import argparse
from tqdm import tqdm
def download():
# 将图片转化为张量以及归一化处理
Trans = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])
# 下载MNIST对应的训练和测试数据集
train_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=Trans,
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=Trans,
)
train_Dataloader = DataLoader(train_data,batch_size=64)
test_Dataloader = DataLoader(test_data,batch_size=999999)
return train_Dataloader, test_Dataloader, train_data, test_data
class VAE(nn.Module):
def __init__(self, input_size = 28*28, hidden_dim = 64):
super(VAE, self).__init__()
self.encoder = nn.Sequential(nn.Linear(input_size,256),nn.ReLU(),nn.Linear(256, 128), nn.ReLU())
self.mean = nn.Linear(128, hidden_dim)
self.logvar = nn.Linear(128, hidden_dim)
self.decoder = nn.Sequential(nn.Linear(hidden_dim, 128), nn.ReLU(), nn.Linear(128,256), nn.ReLU(), nn.Linear(256, input_size), nn.Tanh())
def reparameters(self, mean, logvar):
std = torch.exp(0.5*logvar)
z = torch.randn(std.size(), device=std.device)*std + mean
return z
def forward(self, x):
x = self.encoder(x)
mean = self.mean(x)
logvar = self.logvar(x)
z = self.reparameters(mean, logvar)
new_image = self.decoder(z)
return new_image, mean, logvar
def train(vae, train_Dataloader, optimizer):
for index,(X,_) in tqdm(enumerate(train_Dataloader)):
X = torch.reshape(X, (X.shape[0],-1)).to('cuda')
new_image, mean, logvar = vae(X)
loss_re = ((new_image-X)**2).sum()
loss_kl = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mean ** 2)
loss = loss_re + loss_kl
optimizer.zero_grad()
loss.backward()
optimizer.step()
if __name__ == "__main__":
train_Dataloader, test_Dataloader, _, _ = download()
# 参数解析,方便调参
parser = argparse.ArgumentParser(
description='train',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--lr', type=float, default=0.0005)
parser.add_argument('--epoch', type=int, default=100)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = VAE().to(device=device)
optimizer = Adam(vae.parameters(), lr = args.lr)
for epoch in range(args.epoch):
print("training epoch:", epoch)
train(vae, train_Dataloader, optimizer)
torch.save(vae.state_dict(),'./vae.pth')
print("模型保存完毕")
模型测试:
import torch
from VAE import VAE
import torchvision
import numpy as np
import matplotlib.pyplot as plt
new_vae = VAE().to('cuda')
new_vae.load_state_dict(torch.load('./vae.pth'))
print("模型重载成功")
with torch.no_grad():
x = torch.randn((32, 64)).to('cuda')
fake = new_vae.decoder(x).reshape(-1, 1, 28, 28)
img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
# 将图像张量转换为 NumPy 数组
img_grid_fake_np = img_grid_fake.cpu().numpy()
img_grid_fake_np = np.transpose(img_grid_fake_np, (1, 2, 0))
# 使用 matplotlib 显示图像
plt.figure(figsize=(10, 10))
plt.imshow(img_grid_fake_np)
plt.axis('off') # 不显示坐标轴
plt.show()
运行结果的生成图像: