1. VAE
1.1 The VAE process in detail
First, an input image (256×256) goes through the VAE encoder and comes out as two codes (64×64): a mean m and a σ for each latent position (here σ is already log-transformed).
- How does the encoder produce two codes?
A: A deep network first extracts features; the features are then passed through two relatively shallow heads, one producing the mean m and the other producing σ.
- What does this first encoding stage actually give us?
A: A set of feature maps of fixed size that describe the parameters of a distribution: every latent "pixel" (64×64) has its own mean m and σ. This is unlike a plain CNN, whose output is just a deterministic embedding vector.
Next, for every latent dimension (64×64) a value e is sampled from a standard normal distribution, and the final encoder output is computed as c = exp(σ)·e + m (the reparameterization trick).
- Why sample at all?
A: Sampling forces the model to learn a whole distribution (the decoder then draws samples from it). This is the difference from a plain AE, which only maps individual feature points, so an AE handles unseen feature values poorly.
The decoder receives the sampled code c produced by the encoder and reconstructs the image from it; this guarantees the model learns to decode samples drawn from the distribution.
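The sampling step above can be sketched in a few lines of NumPy. This is only an illustration of the reparameterization c = exp(σ)·e + m, with shapes chosen to mirror the 64×64 latent in the text and σ treated as a log standard deviation (an assumption consistent with the formula above):

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical encoder outputs: one mean m and one log-std sigma per latent position (64x64)
m = rng.normal(size=(64, 64))
log_sigma = rng.normal(size=(64, 64)) * 0.1

# draw e from a standard normal, then reparameterize: c = exp(sigma) * e + m
e = rng.standard_normal(size=(64, 64))
c = np.exp(log_sigma) * e + m

print(c.shape)  # (64, 64)
```

Because the randomness lives entirely in e, gradients can flow through m and σ, which is the whole point of the trick.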
1.2 VAE optimization objectives
There are two optimization objectives:
The first measures how close q(z|x), the distribution obtained by passing image x through the encoder, is to the standard normal distribution we want the encoder to match (convenient, because then a random sample can be pushed through the decoder to generate an image directly).
- Does it have to be the standard normal distribution?
A: We now use a continuous latent variable z and require z ~ N(0, 1). In fact this choice is not mandatory; other continuous distributions would also work.
Plugging into the KL-divergence formula gives the first objective L1.
The second objective is the reconstruction loss: the MSE between the original input x and the image x` obtained by passing the sampled c through the decoder.
- Why use KL divergence for L1, why is the reconstruction loss L2 needed, and is there a principled derivation?
A: Both terms fall out of a formal derivation (maximizing the variational lower bound of the log-likelihood). See Hung-yi Lee's lecture and the blogs below for the full derivation.
blog1 【学习笔记】生成模型——变分自编码器
blog2 Python实战——VAE的理论详解及Pytorch实现
【(强推)李宏毅2021/2022春机器学习课程】 【精准空降到 00:01】
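Written out (using σ_i for the encoder's log-variance, consistent with the code later in this note), the two objectives are, in closed form:

```latex
L_1 = D_{KL}\!\left(q(z|x)\,\Vert\,\mathcal{N}(0,I)\right)
    = \frac{1}{2}\sum_i \left( \exp(\sigma_i) + m_i^2 - 1 - \sigma_i \right),
\qquad
L_2 = \lVert x - \hat{x} \rVert^2,
\qquad
L = L_1 + L_2 .
```

The KL closed form holds because both distributions are Gaussian. Note that the reference implementation below uses a BCE reconstruction term in place of the MSE, which plays the same role for [0,1]-valued pixels.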
1.3 Code implementation
The code here is adapted from blog2 Python实战——VAE的理论详解及Pytorch实现.
1.3.1 AE
- main.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from ae import AE
from torch import nn, optim
import matplotlib.pyplot as plt
plt.style.use("ggplot")
def main(epoch_num):
    # download the MNIST dataset
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)

    # load the MNIST dataset
    # batch_size sets the size of each batch; shuffle controls whether the data order is randomized.
    # The loader shuffles first and then draws batch_size samples at a time.
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # inspect the shape of one batch of images
    x, label = next(iter(mnist_train))  # take the first training batch
    print(' img : ', x.shape)  # img : torch.Size([32, 1, 28, 28]): 32 images per batch, each of shape (1, 28, 28)

    # setup: build the computation pipeline
    device = torch.device('cuda')
    model = AE().to(device)  # create the AE model and move it to the GPU
    print('The structure of our model is shown below: \n')
    print(model)
    loss_function = nn.MSELoss()  # reconstruction loss
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # optimize the model parameters, learning rate 0.001

    # training loop
    loss_epoch = []
    for epoch in range(epoch_num):
        # every epoch iterates over all batches
        for batch_index, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
            # forward pass; model(x) automatically calls model.forward
            x_hat = model(x)
            loss = loss_function(x_hat, x)  # objective: reconstruction error
            # backward pass
            optimizer.zero_grad()  # clear old gradients, otherwise they accumulate
            loss.backward()  # compute gradients, stored on model.parameters
            optimizer.step()  # update the parameters using those gradients
        loss_epoch.append(loss.item())
        if epoch % (epoch_num // 10) == 0:
            print('Epoch [{}/{}] : '.format(epoch, epoch_num), 'loss = ', loss.item())  # loss is a Tensor
        # x, _ = next(iter(mnist_test))  # take a batch of test data
        # with torch.no_grad():
        #     x_hat = model(x)
    return loss_epoch

# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    epoch_num = 100
    loss_epoch = main(epoch_num=epoch_num)
    # plot the training curve
    plt.plot(loss_epoch)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
- ae.py
from torch import nn

class AE(nn.Module):
    def __init__(self):
        # initialize the module state via the parent class
        super(AE, self).__init__()
        # encoder : [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 20),
            nn.ReLU()
        )
        # decoder : [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()  # pixel values lie in [0,1], so Sigmoid fits better than ReLU
        )

    def forward(self, x):
        """
        Forward pass, invoked automatically by model_name(inputs).
        :param x: the input of our training model
        :return: the result of our training model
        """
        batch_size = x.shape[0]  # number of samples in the batch
        # flatten
        # tensor.view() reshapes a tensor; the total number of elements must stay the same.
        # view() does not copy data: the returned tensor shares memory with the original,
        # so modifying one modifies the other.
        x = x.view(batch_size, 784)  # one row per sample
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batch_size, 1, 28, 28)
        return x
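The view() comment above has a direct NumPy analogue: reshape likewise returns a view that shares memory with the original when possible. A small stand-alone illustration (independent of the AE itself):

```python
import numpy as np

batch = np.zeros((32, 1, 28, 28))
flat = batch.reshape(32, 784)  # same 25088 elements, new shape, no copy

# reshape returns a view: writing through one name is visible through the other
flat[0, 0] = 7.0
print(batch[0, 0, 0, 0])  # 7.0
```

This is why the AE can flatten for the linear layers and reshape back to images without any data movement.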
1.3.2 VAE
- main.py
import torch
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
from vae import VAE
import matplotlib.pyplot as plt
import argparse
import os
import shutil
import numpy as np
# plt.style.use("ggplot")
# choose the device to run the model on
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# default command-line arguments
parser = argparse.ArgumentParser(description="Variational Auto-Encoder MNIST Example")
parser.add_argument('--result_dir', type=str, default='./VAEResult', metavar='DIR', help='output directory')
parser.add_argument('--save_dir', type=str, default='./checkPoint', metavar='N', help='model saving directory')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size for training(default: 128)')
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train(default: 200)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed(default: 1)')
parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to latest checkpoint(default: None)')
parser.add_argument('--test_every', type=int, default=10, metavar='N', help='test after every epochs')
parser.add_argument('--num_worker', type=int, default=1, metavar='N', help='the number of workers')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate(default: 0.001)')
parser.add_argument('--z_dim', type=int, default=20, metavar='N', help='the dim of latent variable z(default: 20)')
parser.add_argument('--input_dim', type=int, default=28 * 28, metavar='N', help='input dim(default: 28*28 for MNIST)')
parser.add_argument('--input_channel', type=int, default=1, metavar='N', help='input channel(default: 1 for MNIST)')
args = parser.parse_args()
kwargs = {'num_workers': 2, 'pin_memory': True} if cuda else {}
def dataloader(batch_size=128, num_workers=2):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # download the MNIST dataset
    mnist_train = datasets.MNIST('mnist', train=True, transform=transform, download=True)
    mnist_test = datasets.MNIST('mnist', train=False, transform=transform, download=True)
    # load the MNIST dataset
    # batch_size sets the size of each batch; shuffle controls whether the data order is randomized
    # num_workers sets the number of subprocesses used for data loading
    mnist_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    mnist_test = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
    return mnist_test, mnist_train, classes
def loss_function(x_hat, x, mu, log_var):
    """
    Calculate the loss. Note that the loss includes two parts.
    :param x_hat:
    :param x:
    :param mu:
    :param log_var:
    :return: total loss, BCE and KLD of our model
    """
    # 1. the reconstruction loss.
    # We treat the MNIST pixels as binary targets
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # 2. KL-divergence
    # D_KL(Q(z|X) || P(z)); calculated in closed form as both distributions are Gaussian
    # here we assume that \Sigma is a diagonal matrix, so as to simplify the computation
    KLD = 0.5 * torch.sum(torch.exp(log_var) + torch.pow(mu, 2) - 1. - log_var)
    # 3. total loss
    loss = BCE + KLD
    return loss, BCE, KLD
def save_checkpoint(state, is_best, outdir):
    """
    After every few training epochs, check whether the loss is the best so far and save the model parameters.
    :param state: the dict of parameters to save
    :param is_best: whether this checkpoint is the best so far
    :param outdir: output directory
    :return:
    """
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    checkpoint_file = os.path.join(outdir, 'checkpoint.pth')  # os.path.join builds the file path inside 'outdir'
    best_file = os.path.join(outdir, 'model_best.pth')
    torch.save(state, checkpoint_file)  # save state to the checkpoint file
    if is_best:
        shutil.copyfile(checkpoint_file, best_file)
def test(model, optimizer, mnist_test, epoch, best_test_loss):
    test_avg_loss = 0.0
    with torch.no_grad():  # no gradients here, i.e. nothing is added to the computation graph
        '''evaluate on the test set'''
        # sum the loss over all batches
        for test_batch_index, (test_x, _) in enumerate(mnist_test):
            test_x = test_x.to(device)
            # forward pass
            test_x_hat, test_mu, test_log_var = model(test_x)
            # loss value
            test_loss, test_BCE, test_KLD = loss_function(test_x_hat, test_x, test_mu, test_log_var)
            test_avg_loss += test_loss.item()
        # average over the dataset to get the per-image loss
        test_avg_loss /= len(mnist_test.dataset)

        '''generate images from randomly sampled latent variables'''
        # sample latent variables from the prior
        z = torch.randn(args.batch_size, args.z_dim).to(device)  # one latent variable per row, batch_size rows
        # reconstruct images from the latent variables
        random_res = model.decode(z).view(-1, 1, 28, 28)
        # save the generated images
        save_image(random_res, './%s/random_sampled-%d.png' % (args.result_dir, epoch + 1))

        '''save the current model'''
        is_best = test_avg_loss < best_test_loss
        best_test_loss = min(test_avg_loss, best_test_loss)
        save_checkpoint({
            'epoch': epoch,  # number of completed epochs
            'best_test_loss': best_test_loss,  # best loss so far
            'state_dict': model.state_dict(),  # current model parameters
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_dir)
    return best_test_loss
def main():
    # Step 1: load the data
    mnist_test, mnist_train, classes = dataloader(args.batch_size, args.num_worker)
    # inspect the shape of one batch of images
    x, label = next(iter(mnist_train))  # take the first training batch
    print(' img : ', x.shape)  # img : torch.Size([batch_size, 1, 28, 28])

    # Step 2: setup: build the computation pipeline
    model = VAE(z_dim=args.z_dim).to(device)  # create the VAE model and move it to the GPU
    print('The structure of our model is shown below: \n')
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)  # optimize the model parameters

    # Step 3: optionally resume from a checkpoint
    start_epoch = 0
    best_test_loss = np.finfo('f').max
    if args.resume:
        if os.path.isfile(args.resume):
            # load the previously trained parameters and results
            print('=> loading checkpoint %s' % args.resume)
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_test_loss = checkpoint['best_test_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s' % args.resume)
        else:
            print('=> no checkpoint found at %s' % args.resume)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    # Step 4: training loop
    loss_epoch = []
    for epoch in range(start_epoch, args.epochs):
        # every epoch iterates over all batches
        loss_batch = []
        for batch_index, (x, _) in enumerate(mnist_train):
            # x : [b, 1, 28, 28], remember to deploy the input on GPU
            x = x.to(device)
            # forward pass; model(x) automatically calls model.forward
            x_hat, mu, log_var = model(x)
            loss, BCE, KLD = loss_function(x_hat, x, mu, log_var)  # objective value
            loss_batch.append(loss.item())  # loss is a Tensor
            # backward pass
            optimizer.zero_grad()  # clear old gradients, otherwise they accumulate
            loss.backward()  # compute gradients, stored on model.parameters
            optimizer.step()  # update the parameters using those gradients
            # print statistics every 100 batches
            if (batch_index + 1) % 100 == 0:
                print('Epoch [{}/{}], Batch [{}/{}] : Total-loss = {:.4f}, BCE-Loss = {:.4f}, KLD-loss = {:.4f}'
                      .format(epoch + 1, args.epochs, batch_index + 1, len(mnist_train.dataset) // args.batch_size,
                              loss.item() / args.batch_size, BCE.item() / args.batch_size,
                              KLD.item() / args.batch_size))
            if batch_index == 0:
                # visualize the reconstruction at the beginning of each epoch
                x_concat = torch.cat([x.view(-1, 1, 28, 28), x_hat.view(-1, 1, 28, 28)], dim=3)
                save_image(x_concat, './%s/reconstructed-%d.png' % (args.result_dir, epoch + 1))
        # record the per-sample average loss of this epoch
        loss_epoch.append(np.sum(loss_batch) / len(mnist_train.dataset))  # len(mnist_train.dataset) is the number of samples
        # evaluate the model
        if (epoch + 1) % args.test_every == 0:
            best_test_loss = test(model, optimizer, mnist_test, epoch, best_test_loss)
    return loss_epoch
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    loss_epoch = main()
    # plot the training curve
    plt.plot(loss_epoch)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()
- vae.py
from torch import nn
import torch
import torch.nn.functional as F
class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, z_dim=20):
        # initialize the module state via the parent class
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        # encoder : [b, input_dim] => [b, z_dim]
        self.fc1 = nn.Linear(input_dim, h_dim)  # first fully connected layer
        self.fc2 = nn.Linear(h_dim, z_dim)  # mu
        self.fc3 = nn.Linear(h_dim, z_dim)  # log_var
        # decoder : [b, z_dim] => [b, input_dim]
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, input_dim)

    def forward(self, x):
        """
        Forward pass, invoked automatically by model_name(inputs).
        :param x: the input of our training model [batch_size, 1, 28, 28]
        :return: the result of our training model
        """
        batch_size = x.shape[0]  # number of samples in the batch
        # flatten [batch_size, 1, 28, 28] => [batch_size, 784]
        # tensor.view() reshapes a tensor; the total number of elements must stay the same.
        # view() does not copy data: the returned tensor shares memory with the original.
        x = x.view(batch_size, self.input_dim)  # one row per sample
        # encoder
        mu, log_var = self.encode(x)  # get mu and log_var; log_var is the log of the variance
        # reparameterization trick
        sampled_z = self.reparameterization(mu, log_var)  # sampling
        # decoder
        x_hat = self.decode(sampled_z)
        # reshape
        x_hat = x_hat.view(batch_size, 1, 28, 28)
        return x_hat, mu, log_var

    def encode(self, x):
        """
        encoding part
        :param x: input image
        :return: mu and log_var
        """
        h = F.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu, log_var

    def reparameterization(self, mu, log_var):
        """
        Given a standard gaussian distribution epsilon ~ N(0,1),
        we can sample the random variable z as per z = mu + sigma * epsilon
        :param mu:
        :param log_var:
        :return: sampled z
        """
        sigma = torch.exp(log_var * 0.5)  # exp(0.5 * log_var) converts the log-variance into a standard deviation
        eps = torch.randn_like(sigma)
        return mu + sigma * eps  # "*" here is element-wise multiplication

    def decode(self, z):
        """
        Given a sampled z, decode it back to image
        :param z:
        :return:
        """
        h = F.relu(self.fc4(z))
        x_hat = torch.sigmoid(self.fc5(h))  # pixel values lie in [0,1], so Sigmoid fits better than ReLU
        return x_hat
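The KLD term in loss_function above can be sanity-checked against the textbook formula KL(N(μ,σ²) ‖ N(0,1)) = ½(σ² + μ² − 1 − ln σ²). A small NumPy sketch of the same closed form:

```python
import numpy as np

def kld(mu, log_var):
    # same closed form as the KLD line in the PyTorch loss_function above
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

# when the posterior already equals N(0, 1), the KL term vanishes
print(kld(np.zeros(20), np.zeros(20)))  # 0.0

# any other posterior has strictly positive KL; e.g. mu = 1, log_var = 0 gives 0.5 * 20
print(kld(np.ones(20), np.zeros(20)))   # 10.0
```

So the KL term pulls every latent dimension toward mean 0 and variance 1, while the reconstruction term pulls the other way; training balances the two.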
1.3.3 The VAE used in Stable Diffusion: AutoencoderKL
In Stable Diffusion, the VAE serves only as the Encoder/Decoder pair that maps images to and from the latent space in which diffusion runs; it does not take part in training. This section walks through the implementation of AutoencoderKL.
- class LatentDiffusion(DDPM):
# the place where the VAE is used is the get_input method
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
              cond_key=None, return_original_cond=False, bs=None, return_x=False):
    x = super().get_input(batch, k)
    if bs is not None:
        x = x[:bs]
    x = x.to(self.device)
    encoder_posterior = self.encode_first_stage(x)  # encode into the latent space; at this point still a DiagonalGaussianDistribution
    z = self.get_first_stage_encoding(encoder_posterior).detach()  # then sample from it
    ......

def get_first_stage_encoding(self, encoder_posterior):
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        z = encoder_posterior.sample()
    elif isinstance(encoder_posterior, torch.Tensor):
        z = encoder_posterior
    else:
        raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
    return self.scale_factor * z  # multiplied by scale_factor
- The model is first initialized from a yaml config, which determines which VAE is used; an excerpt from one such yaml file follows. For encoding into the latent space, Stable Diffusion only uses first_stage_config, which instantiates the AutoencoderKL class in ldm/models/autoencoder.py.
......
first_stage_config:
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    embed_dim: 4
    monitor: val/rec_loss
    ddconfig:
      #attn_type: "vanilla-xformers"
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: torch.nn.Identity
......
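From this config the latent shape can be worked out directly: ch_mult has 4 entries and the Encoder downsamples once per level except the last, so a 256×256 image is reduced by a factor of 2³ = 8 to a 32×32 latent with z_channels = 4. A sketch of the arithmetic only (values copied from the yaml above, not the real Encoder):

```python
# values taken from the yaml excerpt above
resolution = 256
ch_mult = (1, 2, 4, 4)
z_channels = 4

# one Downsample per level except the last => total factor 2 ** (len(ch_mult) - 1)
down_factor = 2 ** (len(ch_mult) - 1)
latent_res = resolution // down_factor
print((z_channels, latent_res, latent_res))  # (4, 32, 32)
```

This is why Stable Diffusion runs diffusion on a 4×32×32 latent rather than on the full 3×256×256 image.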
- autoencoder.py
from ldm.modules.diffusionmodules.cldm_model import Encoder, Decoder, Decoder_Mix, Decoder_Mix_Mask
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False
                 ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)  # build the Encoder network structure from the yaml above
        self.decoder = Decoder(**ddconfig)  # build the Decoder network structure from the yaml above
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        # this layer outputs 2 * embed_dim channels because the features are later
        # split into mu and log_var; doubling the width here is simply convenient
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        # a conv layer in front of the decoder
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
    def init_from_ckpt(self, path, ignore_keys=list()):
        ...

    @contextmanager
    def ema_scope(self, context=None):
        ...

    def on_train_batch_end(self, *args, **kwargs):
        ...

    def encode(self, x):
        h = self.encoder(x)  # encode the image x into a feature map h
        moments = self.quant_conv(h)  # produce mu and log_var
        posterior = DiagonalGaussianDistribution(moments)  # hand them over to DiagonalGaussianDistribution
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)  # map z to the number of channels the decoder expects
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        # Stable Diffusion does not call forward directly; its own code path differs only in the
        # sample_posterior flag, and the actual sampling happens later via the posterior object
        posterior = self.encode(input)  # encode first; the posterior is a DiagonalGaussianDistribution
        if sample_posterior:
            z = posterior.sample()  # DiagonalGaussianDistribution.sample() draws one sample from the distribution
        else:
            z = posterior.mode()
        dec = self.decode(z)  # finally hand it to the decoder
        return dec, posterior

    # fetch the requested element of the batch
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x
    def training_step(self, batch, batch_idx, optimizer_idx):
        ...

    def validation_step(self, batch, batch_idx):
        ...

    def _validation_step(self, batch, batch_idx, postfix=""):
        ...

    def configure_optimizers(self):
        ...

    def get_last_layer(self):
        ...

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        ...

    def to_rgb(self, x):
        ...
- The DiagonalGaussianDistribution class
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)  # reparameterized sampling
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
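The chunk-then-sample logic of this class can be mimicked in NumPy to see what it does with the moments tensor. A simplified sketch (not the real class; shapes assume embed_dim = 4 as in the yaml above):

```python
import numpy as np

rng = np.random.default_rng(0)

# moments as produced by quant_conv: mean and log-variance stacked along the channel axis
moments = rng.normal(size=(1, 8, 32, 32))      # 2 * embed_dim = 8 channels
mean, logvar = np.split(moments, 2, axis=1)    # analogous to torch.chunk(parameters, 2, dim=1)
logvar = np.clip(logvar, -30.0, 20.0)
std = np.exp(0.5 * logvar)

# sample() draws z = mean + std * eps; mode() simply returns the mean
z = mean + std * rng.standard_normal(mean.shape)
print(z.shape)  # (1, 4, 32, 32)
```

The logvar clamp keeps exp() numerically safe; it is the same trick the real class uses.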
- The Encoder and Decoder classes
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, return_fea=False):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        fea_list = []
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if return_fea:
                if i_level==1 or i_level==2:
                    fea_list.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)

        if return_fea:
            return h, fea_list
        return h
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
2. GAN
2.1 Unconditional generation
Suppose we already have a trained Generator whose job is to produce anime-style images; any output is acceptable as long as it looks like a plausible anime image. So what we actually want is a complex distribution: the distribution of anime images (over high-dimensional vectors). Given that distribution, feeding in a low-dimensional vector z (drawn from a normal distribution or some other distribution) produces an anime image.
For a GAN, the way to obtain such a Generator while guaranteeing its output quality is to pair it with a Discriminator: it takes an image as input and judges whether it matches the look of a real anime image.
For the generator, the objective is to minimize the distance between Pg and Pdata, i.e. the distance between the two distributions. Computing that distance directly is hard; GAN overcomes the difficulty precisely by using the discriminator.
2.2 Loss function
All loss computation happens at the output of D (the discriminator). Since D outputs a fake/true judgment, the overall loss is the binary cross-entropy, BCELoss.
The objective has two parts: minG and maxD.
Consider maxD first, since training usually starts by fixing G (the generator) and updating D. D's goal is to tell fake from true. If we label true/fake as 1/0, then for the first expectation term, whose inputs are sampled from real data, we want D(x) close to 1, which makes that term larger. Likewise the second term's inputs are sampled from G's outputs, so we want D(G(z)) close to 0, which also makes that term larger. Training therefore pushes the whole expression up, which is exactly what maxD means.
In the second part, D is held fixed and G is trained; now only the second expectation term matters. Here is the trick: to fool D, the labels are set to 1 (we know the samples are fake, which is why it is called fooling), and we want D(G(z)) to get close to 1, i.e. this term should become as small as possible; that is minG. Of course the discriminator is not so easily fooled, so it produces a large error signal; that error updates G, and G improves: it failed to fool D this time, so it tries harder next time.
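The label trick in the two phases can be checked numerically with plain BCE. A toy sketch (the D outputs here are made-up probabilities, not a real network):

```python
import numpy as np

def bce(p, label):
    # binary cross-entropy for a single prediction p against a 0/1 label
    return -(label * np.log(p) + (1 - label) * np.log(1 - p))

# discriminator phase: real images are scored against label 1, generated images against label 0
d_real, d_fake = 0.9, 0.2  # hypothetical D outputs
d_loss = (bce(d_real, 1) + bce(d_fake, 0)) / 2

# generator phase: the same fake output is scored against label 1 to fool D
g_loss = bce(d_fake, 1)

# the closer D(G(z)) is to 1 (i.e. the better G fools D), the smaller the generator loss
print(bce(0.8, 1) < bce(0.2, 1))  # True
```

This mirrors exactly how `valid` and `fake` label tensors are used in the training loop below.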
2.3 Code implementation
GAN has many variants; there is a GitHub repository that collects implementations of most of them quite completely.
A minimal GAN implementation is shown below:
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
# Loss function
adversarial_loss = torch.nn.BCELoss()
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# ----------
# Training
# ----------
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)