VAE and GAN: Architecture, Optimization Objectives, and Code Implementation

1. VAE


1.1 The VAE pipeline in detail

The input is an image (say 256*256). The VAE encoder maps it to two codes of a smaller size (say 64·64): for every latent dimension a mean m and a variance σ, where σ is stored in log form (a log-variance).

  1. How does the encoder turn the image into these two codes?
    A: The image first goes through a fairly deep network that extracts features; those features are then fed into two relatively shallow heads, one producing the mean m and the other the log-variance σ.
  2. What exactly does this first encoding stage produce?
    A: A set of feature maps of fixed size that describe a distribution: every latent "pixel" (64·64 of them) has its own mean m and log-variance σ. This is different from the single embedding vector a plain CNN encoder would output.
    For each latent dimension we then draw e from a standard normal distribution and form the final code as c = m + exp(σ/2)*e (i.e. exp(0.5 * log_var) in the code below; if σ instead denoted the log of the standard deviation, the factor would simply be exp(σ)).
  3. Why sample at all?
    A: Sampling is what forces the model to learn a whole distribution that the decoder can sample from; this is the key difference from a plain AE, which only learns to reconstruct the specific codes it has seen and therefore handles unseen codes poorly.
    The decoder receives a code c sampled from that distribution and decodes it back into an image,
    which guarantees the model learns to generate from samples of the distribution rather than from fixed codes (see the short sketch right after this list).
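A minimal PyTorch sketch of this sampling (reparameterization) step; m, log_var and the shapes are illustrative placeholders for whatever the encoder actually outputs:

import torch

m = torch.zeros(4, 20)        # encoder output: mean, shape [batch, latent_dim]
log_var = torch.zeros(4, 20)  # encoder output: log-variance, same shape

e = torch.randn_like(m)                  # e sampled from a standard normal N(0, I)
c = m + torch.exp(0.5 * log_var) * e     # c = m + sigma * e; still differentiable w.r.t. m and log_var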

1.2 VAE optimization objectives

There are two optimization objectives.
The first measures how close q(z|x), the distribution the encoder produces for an image x, is to the distribution we want the encoder to match: the standard normal distribution (convenient because at generation time we can simply draw a random sample and push it through the decoder).

  1. Does it have to be the standard normal distribution?
    A: We now use a continuous latent variable z and require z ~ N(0, 1). This particular choice is not mandatory; other continuous distributions would also work, but the standard normal is the convenient default.
    Substituting the two Gaussians into the KL-divergence formula gives the objective L1 shown right below this list.
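For a diagonal-Gaussian encoder q(z|x) = N(m, exp(σ)) (σ being the log-variance) and the prior N(0, I), this KL term has the standard closed form below; it is exactly what loss_function computes in the VAE code later on:

    L1 = D_KL( q(z|x) || N(0, I) ) = 1/2 * Σ_i ( exp(σ_i) + m_i^2 − 1 − σ_i )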

The second objective is the reconstruction loss: the difference between the original input x and the image x̂ that the decoder produces from the sampled code c, typically measured with MSE.
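Written out (with x̂ = Decoder(c); the per-pixel BCE variant is what the MNIST code below actually uses):

    L2 = || x − x̂ ||^2        (or, per pixel, BCE(x̂, x))

The total loss that is minimized is L = L1 + L2, which corresponds to loss = BCE + KLD in loss_function below.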

  1. Why is L1 a KL term, and why do we also need the reconstruction loss L2 — is there a principled derivation?
    A: Yes, both terms fall out of maximizing a variational lower bound (the ELBO) on log p(x); see the decomposition right after this list, Hung-yi Lee's lecture, and the blogs below for the full derivation.
    blog1 【学习笔记】生成模型——变分自编码器
    blog2 Python实战——VAE的理论详解及Pytorch实现
    【(强推)李宏毅2021/2022春机器学习课程】 【精准空降到 00:01】
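For reference, the standard identity behind that answer (written in the usual p/q notation):

    log p(x) = E_{q(z|x)}[ log p(x|z) ] − D_KL( q(z|x) || p(z) ) + D_KL( q(z|x) || p(z|x) )
             ≥ E_{q(z|x)}[ log p(x|z) ] − D_KL( q(z|x) || p(z) )          (the ELBO)

Maximizing this lower bound is the same as minimizing L1 + L2: the reconstruction loss L2 corresponds to −E_q[log p(x|z)] and L1 to the KL term D_KL(q(z|x) || p(z)).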


1.3 Code implementation

The code below is adapted from blog2, Python实战——VAE的理论详解及Pytorch实现.

1.3.1 AE

  • main.py
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from ae import AE
from torch import nn, optim
import matplotlib.pyplot as plt

plt.style.use("ggplot")


def main(epoch_num):
    # 下载mnist数据集
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)

    # 载入mnist数据集
    # batch_size设置每一批数据的大小,shuffle设置是否打乱数据顺序,结果表明,该函数会先打乱数据再按batch_size取数据
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # 查看每一个batch图片的规模
    x, label = iter(mnist_train).__next__()  # 取出第一批(batch)训练所用的数据集
    print(' img : ', x.shape)  # img :  torch.Size([32, 1, 28, 28]), 每次迭代获取32张图片,每张图大小为(1,28,28)

    # 准备工作 : 搭建计算流程
    device = torch.device('cuda')
    model = AE().to(device)  # 生成AE模型,并转移到GPU上去
    print('The structure of our model is shown below: \n')
    print(model)
    loss_function = nn.MSELoss()  # 生成损失函数
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # 生成优化器,需要优化的是model的参数,学习率为0.001

    # 开始迭代
    loss_epoch = []
    for epoch in range(epoch_num):
        # 每一代都要遍历所有的批次
        for batch_index, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
            # 前向传播
            x_hat = model(x)  # 模型的输出,在这里会自动调用model中的forward函数
            loss = loss_function(x_hat, x)  # 计算损失值,即目标函数
            # 后向传播
            optimizer.zero_grad()  # 梯度清零,否则上一步的梯度仍会存在
            loss.backward()  # 后向传播计算梯度,这些梯度会保存在model.parameters里面
            optimizer.step()  # 更新梯度,这一步与上一步主要是根据model.parameters联系起来了

        loss_epoch.append(loss.item())
        if epoch % (epoch_num // 10) == 0:
            print('Epoch [{}/{}] : '.format(epoch, epoch_num), 'loss = ', loss.item())  # loss是Tensor类型
            # x, _ = iter(mnist_test).__next__()   # 在测试集中取出一部分数据
            # with torch.no_grad():
            #     x_hat = model(x)

    return loss_epoch


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
	epoch_num = 100
	loss_epoch = main(epoch_num=epoch_num)
	# 绘制迭代结果
	plt.plot(loss_epoch)
	plt.xlabel('epoch')
	plt.ylabel('loss')
	plt.show()


  • ae.py
from torch import nn


class AE(nn.Module):

    def __init__(self):
        # 调用父类方法初始化模块的state
        super(AE, self).__init__()

        # 编码器 : [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 20),
            nn.ReLU()
        )

        # 解码器 : [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()    # 图片数值取值为[0,1],不宜用ReLU
        )

    def forward(self, x):
        """
        向前传播部分, 在model_name(inputs)时自动调用
        :param x: the input of our training model
        :return: the result of our training model
        """
        batch_size = x.shape[0]   # 每一批含有的样本的个数
        # flatten
        # tensor.view()方法可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,
        # 返回的新tensor与原tensor共享内存,即更改一个,另一个也随之改变。
        x = x.view(batch_size, 784)  # 一行代表一个样本

        # encoder
        x = self.encoder(x)

        # decoder
        x = self.decoder(x)

        # reshape
        x = x.view(batch_size, 1, 28, 28)
        return x


1.3.2 VAE

  • main.py
import torch
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
from vae import VAE
import matplotlib.pyplot as plt
import argparse
import os
import shutil
import numpy as np


# plt.style.use("ggplot")

# 设置模型运行的设备
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# 设置默认参数
parser = argparse.ArgumentParser(description="Variational Auto-Encoder MNIST Example")
parser.add_argument('--result_dir', type=str, default='./VAEResult', metavar='DIR', help='output directory')
parser.add_argument('--save_dir', type=str, default='./checkPoint', metavar='N', help='model saving directory')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size for training(default: 128)')
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train(default: 200)')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed(default: 1)')
parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to latest checkpoint(default: None)')
parser.add_argument('--test_every', type=int, default=10, metavar='N', help='test after every epochs')
parser.add_argument('--num_worker', type=int, default=1, metavar='N', help='the number of workers')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate(default: 0.001)')
parser.add_argument('--z_dim', type=int, default=20, metavar='N', help='the dim of latent variable z(default: 20)')
parser.add_argument('--input_dim', type=int, default=28 * 28, metavar='N', help='input dim(default: 28*28 for MNIST)')
parser.add_argument('--input_channel', type=int, default=1, metavar='N', help='input channel(default: 1 for MNIST)')
args = parser.parse_args()
kwargs = {'num_workers': 2, 'pin_memory': True} if cuda else {}


def dataloader(batch_size=128, num_workers=2):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # 下载mnist数据集
    mnist_train = datasets.MNIST('mnist', train=True, transform=transform, download=True)
    mnist_test = datasets.MNIST('mnist', train=False, transform=transform, download=True)

    # 载入mnist数据集
    # batch_size设置每一批数据的大小,shuffle设置是否打乱数据顺序,结果表明,该函数会先打乱数据再按batch_size取数据
    # num_workers设置载入输入所用的子进程的个数
    mnist_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    mnist_test = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

    classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
    return mnist_test, mnist_train, classes


def loss_function(x_hat, x, mu, log_var):
    """
    Calculate the loss. Note that the loss includes two parts.
    :param x_hat:
    :param x:
    :param mu:
    :param log_var:
    :return: total loss, BCE and KLD of our model
    """
    # 1. the reconstruction loss.
    # We regard the MNIST as binary classification
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')

    # 2. KL-divergence
    # D_KL(Q(z|X) || P(z)); calculate in closed form as both dist. are Gaussian
    # here we assume that \Sigma is a diagonal matrix, so as to simplify the computation
    KLD = 0.5 * torch.sum(torch.exp(log_var) + torch.pow(mu, 2) - 1. - log_var)

    # 3. total loss
    loss = BCE + KLD
    return loss, BCE, KLD


def save_checkpoint(state, is_best, outdir):
    """
    每训练一定的epochs后, 判断损失函数是否是目前最优的,并保存模型的参数
    :param state: 需要保存的参数,数据类型为dict
    :param is_best: 说明是否为目前最优的
    :param outdir: 保存文件夹
    :return:
    """
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    checkpoint_file = os.path.join(outdir, 'checkpoint.pth')  # join函数创建子文件夹,也就是把第二个参数对应的文件保存在'outdir'里
    best_file = os.path.join(outdir, 'model_best.pth')
    torch.save(state, checkpoint_file)  # 把state保存在checkpoint_file文件夹中
    if is_best:
        shutil.copyfile(checkpoint_file, best_file)


def test(model, optimizer, mnist_test, epoch, best_test_loss):
    test_avg_loss = 0.0
    with torch.no_grad():  # 这一部分不计算梯度,也就是不放入计算图中去
        '''测试测试集中的数据'''
        # 计算所有batch的损失函数的和
        for test_batch_index, (test_x, _) in enumerate(mnist_test):
            test_x = test_x.to(device)
            # 前向传播
            test_x_hat, test_mu, test_log_var = model(test_x)
            # 损害函数值
            test_loss, test_BCE, test_KLD = loss_function(test_x_hat, test_x, test_mu, test_log_var)
            test_avg_loss += test_loss

        # 对和求平均,得到每一张图片的平均损失
        test_avg_loss /= len(mnist_test.dataset)

        '''测试随机生成的隐变量'''
        # 随机从隐变量的分布中取隐变量
        z = torch.randn(args.batch_size, args.z_dim).to(device)  # 每一行是一个隐变量,总共有batch_size行
        # 对隐变量重构
        random_res = model.decode(z).view(-1, 1, 28, 28)
        # 保存重构结果
        save_image(random_res, './%s/random_sampled-%d.png' % (args.result_dir, epoch + 1))

        '''保存目前训练好的模型'''
        # 保存模型
        is_best = test_avg_loss < best_test_loss
        best_test_loss = min(test_avg_loss, best_test_loss)
        save_checkpoint({
            'epoch': epoch,  # 迭代次数
            'best_test_loss': best_test_loss,  # 目前最佳的损失函数值
            'state_dict': model.state_dict(),  # 当前训练过的模型的参数
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_dir)

        return best_test_loss


def main():
    # Step 1: 载入数据
    mnist_test, mnist_train, classes = dataloader(args.batch_size, args.num_worker)

    # 查看每一个batch图片的规模
    x, label = iter(mnist_train).__next__()  # 取出第一批(batch)训练所用的数据集
    print(' img : ', x.shape)  # img :  torch.Size([batch_size, 1, 28, 28]), 每次迭代获取batch_size张图片,每张图大小为(1,28,28)

    # Step 2: 准备工作 : 搭建计算流程
    model = VAE(z_dim=args.z_dim).to(device)  # 生成AE模型,并转移到GPU上去
    print('The structure of our model is shown below: \n')
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)  # 生成优化器,需要优化的是model的参数,学习率为0.001

    # Step 3: optionally resume(恢复) from a checkpoint
    start_epoch = 0
    best_test_loss = np.finfo('f').max
    if args.resume:
        if os.path.isfile(args.resume):
            # 载入已经训练过的模型参数与结果
            print('=> loading checkpoint %s' % args.resume)
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_test_loss = checkpoint['best_test_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s' % args.resume)
        else:
            print('=> no checkpoint found at %s' % args.resume)

    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    # Step 4: 开始迭代
    loss_epoch = []
    for epoch in range(start_epoch, args.epochs):

        # 训练模型
        # 每一代都要遍历所有的批次
        loss_batch = []
        for batch_index, (x, _) in enumerate(mnist_train):
            # x : [b, 1, 28, 28], remember to deploy the input on GPU
            x = x.to(device)

            # 前向传播
            x_hat, mu, log_var = model(x)  # 模型的输出,在这里会自动调用model中的forward函数
            loss, BCE, KLD = loss_function(x_hat, x, mu, log_var)  # 计算损失值,即目标函数
            loss_batch.append(loss.item())  # loss是Tensor类型

            # 后向传播
            optimizer.zero_grad()  # 梯度清零,否则上一步的梯度仍会存在
            loss.backward()  # 后向传播计算梯度,这些梯度会保存在model.parameters里面
            optimizer.step()  # 更新梯度,这一步与上一步主要是根据model.parameters联系起来了

            # print statistics every 100 batch
            if (batch_index + 1) % 100 == 0:
                print('Epoch [{}/{}], Batch [{}/{}] : Total-loss = {:.4f}, BCE-Loss = {:.4f}, KLD-loss = {:.4f}'
                      .format(epoch + 1, args.epochs, batch_index + 1, len(mnist_train.dataset) // args.batch_size,
                              loss.item() / args.batch_size, BCE.item() / args.batch_size,
                              KLD.item() / args.batch_size))

            if batch_index == 0:
                # visualize reconstructed result at the beginning of each epoch
                x_concat = torch.cat([x.view(-1, 1, 28, 28), x_hat.view(-1, 1, 28, 28)], dim=3)
                save_image(x_concat, './%s/reconstructed-%d.png' % (args.result_dir, epoch + 1))

        # 把这一个epoch的每一个样本的平均损失存起来
        loss_epoch.append(np.sum(loss_batch) / len(mnist_train.dataset))  # len(mnist_train.dataset)为样本个数

        # 测试模型
        if (epoch + 1) % args.test_every == 0:
            best_test_loss = test(model, optimizer, mnist_test, epoch, best_test_loss)
    return loss_epoch


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    loss_epoch = main()
    # 绘制迭代结果
    plt.plot(loss_epoch)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()

  • vae.py
from torch import nn
import torch
import torch.nn.functional as F


class VAE(nn.Module):

    def __init__(self, input_dim=784, h_dim=400, z_dim=20):
        # 调用父类方法初始化模块的state
        super(VAE, self).__init__()

        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # 编码器 : [b, input_dim] => [b, z_dim]
        self.fc1 = nn.Linear(input_dim, h_dim)  # 第一个全连接层
        self.fc2 = nn.Linear(h_dim, z_dim)  # mu
        self.fc3 = nn.Linear(h_dim, z_dim)  # log_var

        # 解码器 : [b, z_dim] => [b, input_dim]
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, input_dim)

    def forward(self, x):
        """
        向前传播部分, 在model_name(inputs)时自动调用
        :param x: the input of our training model, shape [batch_size, 1, 28, 28]
        :return: the result of our training model
        """
        batch_size = x.shape[0]  # 每一批含有的样本的个数
        # flatten : [batch_size, 1, 28, 28] => [batch_size, 784]
        # tensor.view()方法可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,
        # 返回的新tensor与原tensor共享内存,即更改一个,另一个也随之改变。
        x = x.view(batch_size, self.input_dim)  # 一行代表一个样本

        # encoder
        mu, log_var = self.encode(x)  # mu and log_var (the log of the variance)
        # reparameterization trick
        sampled_z = self.reparameterization(mu, log_var)  # sample z from N(mu, exp(log_var))
        # decoder
        x_hat = self.decode(sampled_z)
        # reshape
        x_hat = x_hat.view(batch_size, 1, 28, 28)
        return x_hat, mu, log_var

    def encode(self, x):
        """
        encoding part
        :param x: input image
        :return: mu and log_var
        """
        h = F.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)

        return mu, log_var

    def reparameterization(self, mu, log_var):
        """
        Given a standard gaussian distribution epsilon ~ N(0,1),
        we can sample the random variable z as per z = mu + sigma * epsilon
        :param mu:
        :param log_var:
        :return: sampled z
        """
        sigma = torch.exp(log_var * 0.5)  # std = exp(0.5 * log_var), since log_var = log(sigma^2)
        eps = torch.randn_like(sigma)
        return mu + sigma * eps  # element-wise: z = mu + sigma * eps

    def decode(self, z):
        """
        Given a sampled z, decode it back to image
        :param z:
        :return:
        """
        h = F.relu(self.fc4(z))
        x_hat = torch.sigmoid(self.fc5(h))  # 图片数值取值为[0,1],不宜用ReLU
        return x_hat

1.3.3 AutoencoderKL: the VAE used in Stable Diffusion

In Stable Diffusion the VAE only acts as the encoder/decoder that maps images to and from the latent space in which diffusion runs; it is pretrained and frozen, so it does not take part in training the diffusion model. The following walks through the AutoencoderKL implementation.

  • class LatentDiffusion(DDPM):
# AutoencoderKL is used inside LatentDiffusion.get_input:
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)  # encode to the latent space; at this point it is still a DiagonalGaussianDistribution
        z = self.get_first_stage_encoding(encoder_posterior).detach()  # then sample from it
......

def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z  # multiply by scale_factor (0.18215 in the SD v1 configs, so the latents have roughly unit variance)
  • Which VAE is used is decided by the yaml config. Below is an excerpt: for Stable Diffusion, only first_stage_config matters for encoding to the latent space, and it instantiates AutoencoderKL from ldm/models/autoencoder.py. A sketch of how this instantiation happens follows the yaml excerpt.
......
first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
......
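A hedged sketch of how this yaml turns into a model object (the config path is illustrative; it assumes the standard latent-diffusion repo utilities OmegaConf and ldm.util.instantiate_from_config):

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")  # illustrative path
first_stage_cfg = config.model.params.first_stage_config
vae = instantiate_from_config(first_stage_cfg)  # instantiates ldm.models.autoencoder.AutoencoderKL(**params)
vae.eval()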
  • autoencoder.py
from ldm.modules.diffusionmodules.cldm_model import Encoder, Decoder, Decoder_Mix, Decoder_Mix_Mask
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ema_decay=None,
                 learn_logvar=False
                 ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)  # Encoder network built from the ddconfig in the yaml above
        self.decoder = Decoder(**ddconfig)  # Decoder network built from the ddconfig in the yaml above
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]

        # 1x1 conv producing 2*embed_dim channels; they are later split into mu and log_var,
        # which is why everything is doubled here
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        # 1x1 conv on the decoder side, mapping the latent back to z_channels
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0. < ema_decay < 1.
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        ...

    @contextmanager
    def ema_scope(self, context=None):
        ...

    def on_train_batch_end(self, *args, **kwargs):
        ...

    def encode(self, x):
        h = self.encoder(x)  # encode the image x into a feature map h
        moments = self.quant_conv(h)  # produce the concatenated mu and log_var
        posterior = DiagonalGaussianDistribution(moments)  # wrap them as a diagonal Gaussian posterior
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)  # map z back to the channel count the decoder expects
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        # Stable Diffusion itself does not call forward(); it calls encode()/decode() separately
        # and performs the sampling in get_first_stage_encoding, but the logic is the same as here.
        posterior = self.encode(input)  # posterior is a DiagonalGaussianDistribution
        if sample_posterior:
            z = posterior.sample()  # draw one sample from the posterior
        else:
            z = posterior.mode()  # or just take the mean
        dec = self.decode(z)  # decode the latent back to an image
        return dec, posterior

    # fetch the image tensor for key k from the batch
    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        ...

    def validation_step(self, batch, batch_idx):
        ...

    def _validation_step(self, batch, batch_idx, postfix=""):
        ...

    def configure_optimizers(self): 
        ...

    def get_last_layer(self):
       ...

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        ...

    def to_rgb(self, x):
        ...
  • The DiagonalGaussianDistribution class
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)  # reparameterized sample: mean + std * eps
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
  • The Encoder and Decoder classes
class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, return_fea=False):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        fea_list = []
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if return_fea:
                if i_level==1 or i_level==2:
                    fea_list.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)

        if return_fea:
            return h, fea_list

        return h

class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
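Putting the pieces together, a hedged usage sketch of the first-stage VAE (it assumes vae is the AutoencoderKL instance built as in the instantiation sketch above, and that scale_factor = 0.18215 as in the SD v1 configs):

import torch

x = torch.randn(1, 3, 256, 256)         # stand-in for an image normalized to [-1, 1], NCHW
with torch.no_grad():
    posterior = vae.encode(x)           # a DiagonalGaussianDistribution
    z = 0.18215 * posterior.sample()    # the latent handed to the diffusion model (cf. get_first_stage_encoding)
    rec = vae.decode(z / 0.18215)       # undo the scaling before decoding, as LatentDiffusion.decode_first_stage does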

2. GAN

2.1 Unconditional generation

Suppose we already have a trained Generator whose job is to produce anime-style images; any output is fine as long as it looks like an anime image. What we really want the generator to capture is a complex distribution: the distribution of anime images (high-dimensional vectors). Once we have it, feeding in a low-dimensional vector z (drawn from a normal distribution, or some other simple distribution) is enough to generate an anime image.
To obtain such a Generator with GAN training, we pair it with a Discriminator: given an image, it judges whether the image looks like a real anime image.


For the generator, the goal is to minimize the distance between Pg (the distribution of generated images) and Pdata (the distribution of real images). Computing this distance between two distributions directly is hard, but GAN sidesteps the problem by using the discriminator to estimate it.
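As a concrete picture of the "feed in a low-dimensional z, get an image out" step described above, a minimal hedged sketch (G here is a stand-in module; in practice it would be a trained generator such as the one defined in 2.3):

import torch

latent_dim = 100
G = torch.nn.Sequential(torch.nn.Linear(latent_dim, 28 * 28), torch.nn.Tanh())  # stand-in generator

z = torch.randn(16, latent_dim)       # z drawn from a simple distribution (standard normal)
imgs = G(z).view(16, 1, 28, 28)       # the generator maps z to samples from the complex image distribution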

2.2 Loss function

min_G max_D V(D, G) = E_{x ~ p_data(x)}[ log D(x) ] + E_{z ~ p_z(z)}[ log( 1 − D(G(z)) ) ]

All the loss is computed at the output of D (the discriminator), and D's output is a fake/real judgement, so the overall loss is a binary cross-entropy (BCELoss).

The objective above has two parts: min over G and max over D.

Consider the max_D part first, since training usually starts by fixing G (the generator) and updating D. D's goal is to tell real from fake: if we label real as 1 and fake as 0, then for the first expectation the input x is sampled from real data, so we want D(x) close to 1, which makes the first term larger. Likewise, the input of the second expectation comes from G, so we want D(G(z)) close to 0, which makes log(1 − D(G(z))) larger as well. Training D therefore pushes the whole objective up, which is exactly what max_D means.

In the second part we fix D and train G; now only the second term matters. Here is the key trick: since we want to fool D, we feed the generated samples with label 1 (we know they are fake, which is precisely why it counts as fooling) and hope D(G(z)) comes out close to 1, i.e. we make the second term as small as possible. That is the min_G. Of course D is not so easily fooled, so this step produces a large loss; its gradient updates G and makes it better. If it failed to fool D this time, it tries harder the next time.
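Concretely, with the 1/0 labels and BCE this becomes the two training objectives that the code in 2.3 implements:

    L_D = − E_x[ log D(x) ] − E_z[ log(1 − D(G(z))) ]      (BCE with label 1 for real images, 0 for generated ones)
    L_G = − E_z[ log D(G(z)) ]                             (BCE of D(G(z)) against label 1)

Minimizing L_G pushes D(G(z)) toward 1; it has the same fixed point as the min_G above but is the standard non-saturating form, which gives the generator stronger gradients early in training when D easily rejects the fakes.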

2.3 Code implementation

GAN has many variants; there is a GitHub repository that collects implementations of most of them fairly completely.
A minimal GAN implementation is shown below:

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
