VQ-VAE (Vector Quantized Variational Autoencoder): Code Walkthrough and Implementation

Before reading this article, please be familiar with the theory write-up I posted earlier; the names in the code are kept as close to the original paper as possible (ze, z, zq, and so on).
VQ-VAE theory
The theory is not repeated here. As before, the MNIST dataset is used.

1. Sampling model: PixelCNN

As we know, VQ-VAE cannot draw random samples by itself the way a VAE can; strictly speaking it is not really a VAE, so an auxiliary model is needed to generate random samples of the discrete latent. The paper uses PixelCNN for this. If you want to learn about PixelCNN, see my earlier introduction (PixelCNN theory); this article uses the gated variant of PixelCNN. Let's go straight to the code.

import torch.nn as nn
import torch
from torchinfo import summary
import torch.nn.functional as F


class VerticalMaskConv2d(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(*args, **kwargs)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2 + 1] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class HorizontalMaskConv2d(nn.Module):

    def __init__(self, conv_type, *args, **kwargs):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwargs)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class GatedBlock(nn.Module):

    def __init__(self, conv_type, in_channels, p, bn=True):
        super().__init__()
        self.conv_type = conv_type
        self.p = p
        self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)
        self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
                                           1)
        self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_output_conv = nn.Conv2d(p, p, 1)
        self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

    def forward(self, v_input, h_input):
        v = self.v_conv(v_input)
        v = self.bn1(v)
        v_to_h = v[:, :, 0:-1]
        v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
        v_to_h = self.v_to_h_conv(v_to_h)
        v_to_h = self.bn2(v_to_h)

        v1, v2 = v[:, :self.p], v[:, self.p:]
        v1 = torch.tanh(v1)
        v2 = torch.sigmoid(v2)
        v = v1 * v2

        h = self.h_conv(h_input)
        h = self.bn3(h)
        h = h + v_to_h
        h1, h2 = h[:, :self.p], h[:, self.p:]
        h1 = torch.tanh(h1)
        h2 = torch.sigmoid(h2)
        h = h1 * h2
        h = self.h_output_conv(h)
        h = self.bn4(h)
        if self.conv_type == 'B':
            h = h + h_input
        return v, h


class GatedPixelCNN(nn.Module):

    def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.block1 = GatedBlock('A', 1, p, bn)
        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(GatedBlock('B', p, p, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(p, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        v, h = self.block1(x, x)
        for block in self.blocks:
            v, h = block(v, h)
        x = self.relu(h)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x


class PixelCnnWithEmbedding(GatedPixelCNN):
    def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
        super().__init__(n_blocks, p, linear_dim, bn, color_level)
        self.embedding = nn.Embedding(color_level, p)
        self.block1 = GatedBlock('A', p, p, bn)

    def forward(self, x):
        """
        x: (N, H, W), 离散编码z作为输入
        return: (N, 256, H, W)
        """
        x = self.embedding(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return super().forward(x)


if __name__ == '__main__':

    # net1 = GatedPixelCNN(15, 128, 32)
    # net1.block1 = GatedBlock('A', 128, 128, True)
    # summary(net1, input_size=(1, 128, 28, 28))

    net2 = PixelCnnWithEmbedding(15, 128, 32)
    input_data = torch.randint(0,256, (1,28,28))
    summary(net2, input_data=input_data)

Running this shows the rough structure of the model; the input and output H and W are identical. Note that this PixelCNN is used to generate the discrete latent z, not the raw image, since the whole point is to obtain random samples of the discrete latent.
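As a quick sanity check (my own addition, not part of the original post), we can verify the causal masking directly: the logits at position (i, j) must not change when a pixel that comes later in raster-scan order is modified. This sketch assumes the code above is saved as pixelcnn.py; BatchNorm has to be in eval mode here, otherwise batch statistics would leak information across positions.

import torch
from pixelcnn import PixelCnnWithEmbedding  # assuming the code above is saved as pixelcnn.py

net = PixelCnnWithEmbedding(n_blocks=2, p=16, linear_dim=8, color_level=32)
net.eval()  # important: BatchNorm in train mode mixes statistics across positions

x = torch.randint(0, 32, (1, 28, 28))
i, j = 10, 10
with torch.no_grad():
    out1 = net(x)[0, :, i, j]
    x2 = x.clone()
    x2[0, i, j + 1] = (x2[0, i, j + 1] + 1) % 32  # perturb a "future" pixel
    out2 = net(x2)[0, :, i, j]
print(torch.allclose(out1, out2))  # expected: True, (i, j) does not see (i, j+1)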

2. The VQ-VAE model

With the auxiliary sampling model in place, the next piece is the VQ-VAE itself. Straight to the code.

import torch.nn as nn
import torch


# 1. residual block
class ResidualBlock(nn.Module):

    def __init__(self, dim):
        super(ResidualBlock, self).__init__()

        self.res_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, 1)
        )

    def forward(self, x):
        x = x + self.res_block(x)
        return x


class VQVAE(nn.Module):
    def __init__(self, input_dim, dim, n_embedding):
        """
        input_dim: 输入通道数,比如3,输入的图片是3通道的
        dim:编码后ze的通道数
        n_embedding:code book 向量的个数
        """
        super(VQVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, dim, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(dim, dim, 3, 1, 1),
            ResidualBlock(dim),
            ResidualBlock(dim)
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            ResidualBlock(dim),
            ResidualBlock(dim),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1)
        )
        self.n_downsample = 2

        # code book
        self.vq_embedding = nn.Embedding(n_embedding, dim)
        self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0/n_embedding)

    def forward(self, x):
        """
        x, shape(N,C0,H0,W0)
        """
        # encoder (N,C,H,W)
        ze = self.encoder(x)

        # code book embedding [K, C]
        embedding = self.vq_embedding.weight.data

        N, C, H, W = ze.shape
        K, _ = embedding.shape

        embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
        ze_broadcast = ze.reshape(N, 1, C, H, W)

        # nearest neighbor: this step obtains zq. We first map ze -> z (the nearest
        # codebook index), then z -> zq; z is only an intermediate variable. Since the
        # argmin is not differentiable, (zq - ze).detach() below is used to bypass it
        # in the backward pass (straight-through estimator).
        distance = torch.sum((embedding_broadcast - ze_broadcast) ** 2, 2)  # (N,K,H,W)
        nearest_neighbor = torch.argmin(distance, 1)  # (N,H,W)

        # zq (N, C, H, W) : (N, H, W, C) -> (N, C, H, W)
        zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)

        # straight-through: ze + sg(zq - ze)
        decoder_input = ze + (zq - ze).detach()

        # decoder
        x_hat = self.decoder(decoder_input)

        return x_hat, ze, zq

    # encode_z: obtain the discrete latent z (analogous to pixel values); it serves as both input and target for training the PixelCNN, whose job is to reconstruct / generate z
    @torch.no_grad()
    def encode_z(self, x):
        ze = self.encoder(x)
        embedding = self.vq_embedding.weight.data

        # ze: [N, C, H, W]
        # embedding [K, C]
        N, C, H, W = ze.shape
        K, _ = embedding.shape
        embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
        ze_broadcast = ze.reshape(N, 1, C, H, W)
        distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
        nearest_neighbor = torch.argmin(distance, 1)
        return nearest_neighbor

    # decode_z: map the latent produced by the PixelCNN to the final image, since what the PixelCNN generates is the discrete latent z
    @torch.no_grad()
    def decode_z(self, latent_z):
        """
        latent: shape, (N, H, W)
        """
        # zq (N, C, H, W)
        zq = self.vq_embedding(latent_z).permute(0,3,1,2)
        x_hat = self.decoder(zq)
        return x_hat

    # shape: [C,H,W]
    def get_latent_HW(self, input_shape):
        C, H, W = input_shape
        return H // 2 ** self.n_downsample, W // 2 ** self.n_downsample


if __name__ == '__main__':
    from torchinfo import summary
    vqvae = VQVAE(1, 32, 32)
    summary(vqvae, input_size=[1,1,28,28])

There is not much else to say here; the important points are all in the code comments. Pay particular attention to the conversions among ze, z and zq, and to the detach() trick used to cut the computation graph.
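To make the straight-through trick more concrete, here is a tiny standalone sketch (my own illustration, not part of the model code) of the ze -> z -> zq path on a toy codebook; the argmin itself has no gradient, but ze + (zq - ze).detach() lets the decoder gradient flow back into ze:

import torch

codebook = torch.randn(4, 2)                 # K=4 codebook vectors of dimension C=2
ze = torch.randn(3, 2, requires_grad=True)   # three "pixels", each a C=2 encoder output

# ze -> z: index of the nearest codebook vector (argmin is not differentiable)
dist = ((ze[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (3, 4)
z = dist.argmin(1)                                             # (3,)

# z -> zq: look up the codebook vectors
zq = codebook[z]                                               # (3, 2)

# straight-through: the forward value is zq, but the gradient goes to ze
decoder_input = ze + (zq - ze).detach()
decoder_input.sum().backward()
print(z)        # the discrete codes
print(ze.grad)  # all ones: the gradient skipped the argmin and reached ze

The three tensors play exactly the same roles as ze, nearest_neighbor (z) and zq in forward() above.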

3. Training and inference code

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import cv2
from vqvae import VQVAE
from pixelcnn import GatedPixelCNN, PixelCnnWithEmbedding
import einops
import numpy as np

# MNIST is used as the dataset again
# take a quick look at what MNIST looks like
def mnist_show():
    mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)
    print('length of MNIST', len(mnist))
    img, label = mnist[0]
    print(img)
    print(label)
    img.show()
    tensor = transforms.ToTensor()(img)
    print(tensor.shape)   # torch.Size([1, 28, 28])  CHW
    print(tensor.max())    # max is 1
    print(tensor.min())    # min is 0, already normalized to [0, 1]


# mnist_show()


def train_vqvae(
        model:VQVAE,
        device,
        dataloader,
        ckpt_vqvae='vqvae_ckpt.pth',
        n_epochs=100,
        alpha=1,
        beta=0.25,
):

    model.to(device)  # model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    mse_loss = torch.nn.MSELoss()

    print("start vqvae train...")
    for epo in range(n_epochs):

        for img, label in dataloader:
            x = img.to(device)  # N1HW
            x_hat, ze, zq = model(x)

            # ||x - decoder(ze+sg(zq-ze))||
            loss_rec = mse_loss(x, x_hat)

            # ||zq - sg(ze)||
            loss_zq = mse_loss(zq, ze.detach())

            # ||sg(zq) - ze||
            loss_ze = mse_loss(zq.detach(), ze)

            loss = loss_rec + alpha * loss_zq + beta * loss_ze

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"epoch:{epo}, loss:{loss.item():.6f}")

        if epo % 10 == 0:
            torch.save(model.state_dict(), ckpt_vqvae)

    print("vqvae train finish!!")


def train_gen(
        vqvae:VQVAE,
        model,
        device,
        dataloader,
        ckpt_gen="gen_ckpt.pth",
        n_epochs=50,
):
    vqvae = vqvae.to(device)
    model = model.to(device)
    vqvae.eval()
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    print("start pixel cnn train...")
    for epo in range(n_epochs):

        for x, _ in dataloader:

            with torch.no_grad():
                x = x.to(device)

                # get the discrete latent z
                z = vqvae.encode_z(x)

            # use the PixelCNN to model this discrete latent z; note that it reconstructs z, not x (z -> z)
            predict_z = model(z)
            loss = loss_fn(predict_z, z)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"epoch:{epo}, loss:{loss.item():.6f}")

        if epo % 10 == 0:
            torch.save(model.state_dict(), ckpt_gen)

    print("pixel train finish!!")


# visualize the VQ-VAE reconstruction
def reconstruct(model, x, device):
    model.to(device)
    model.eval()
    with torch.no_grad():
        x_hat, _, _ = model(x)

    n = x.shape[0]
    n1 = int(n**0.5)
    x_cat = torch.concat((x, x_hat), 3)
    x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)
    x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)
    cv2.imwrite(f'reconstruct_show.jpg', x_cat)


# visualize the final generated samples
def sample_imgs(
        vqvae:VQVAE,
        gen_model,
        img_shape,
        device,
        n_sample=81
):
    vqvae = vqvae.to(device)
    gen_model = gen_model.to(device)

    vqvae.eval()
    gen_model.eval()

    # get the latent-space H, W
    C,H,W = img_shape
    H, W = vqvae.get_latent_HW((C,H,W))

    input_shape = (n_sample, H, W)
    latent_z = torch.zeros(input_shape).to(device).to(torch.long)
    # pixel cnn sample
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                output = gen_model(latent_z)
                prob_dist = torch.softmax(output[:, :, i, j], -1)
                pixel = torch.multinomial(prob_dist, 1)
                latent_z[:, i, j] = pixel[:, 0]


    # vqvae decode: z -> x_hat
    imgs = vqvae.decode_z(latent_z)

    imgs = imgs * 255
    imgs = imgs.clip(0, 255)
    imgs = einops.rearrange(imgs,
                            '(n1 n2) c h w -> (n1 h) (n2 w) c',
                            n1=int(n_sample**0.5))
    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite('sample_show.jpg', imgs)


def main():

    """ 代码中的公式符号尽可能和原论文一致,避免混淆,尤其是ze,z,zq这几个概念 """

    device = torch.device("cuda:0")
    mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True, transform=transforms.ToTensor())
    dataloader = DataLoader(mnist, batch_size=512, shuffle=True)

    # 0. build the models

    vqvae = VQVAE(1, 32, 32)
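    # note (my addition): color_level defaults to 256 here, while the codebook above only
    # has n_embedding=32 entries; passing color_level=32 would make the two match exactly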
    gen_model = PixelCnnWithEmbedding(15, 128, 32)

    # 1. train the vqvae (reconstruction)
    train_vqvae(vqvae, device, dataloader)

    # 2. train the generative model (sampling)

    vqvae.load_state_dict(torch.load('vqvae_ckpt.pth'))
    train_gen(vqvae, gen_model, device, dataloader)

    gen_model.load_state_dict(torch.load('gen_ckpt.pth'))


def test():
    # after training, check the results
    device = torch.device("cuda:0")

    mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True, transform=transforms.ToTensor())
    dataloader = DataLoader(mnist, batch_size=64, shuffle=True)
    batch_imgs, _ = next(iter(dataloader))

    # vqvae
    vqvae = VQVAE(1, 32, 32)
    vqvae.load_state_dict(torch.load('vqvae_ckpt.pth'))
    vqvae.eval()
    vqvae = vqvae.to(device)
    batch_imgs = batch_imgs.to(device)
    reconstruct(vqvae, batch_imgs, device)

    gen_model = PixelCnnWithEmbedding(15, 128, 32)
    gen_model.load_state_dict(torch.load('gen_ckpt.pth'))
    gen_model.eval()
    gen_model = gen_model.to(device)
    sample_imgs(vqvae, gen_model, (1, 28, 28), device)


if __name__ == '__main__':
    main()

    # after training completes, run test() to check the results
    # test()

The main thing to look at here is how the three loss terms are built; the rest is straightforward.
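Written out, the total loss optimized in train_vqvae (with sg(.) denoting stop-gradient, i.e. detach(), and the defaults alpha = 1, beta = 0.25) is:

loss = ||x - x_hat||^2 + alpha * ||zq - sg(ze)||^2 + beta * ||ze - sg(zq)||^2,   where x_hat = decoder(ze + sg(zq - ze))

The first term is the reconstruction loss, whose gradient reaches the encoder through the straight-through estimator; the second pulls the codebook vectors toward the encoder outputs; the third (commitment) term keeps the encoder outputs close to their chosen codebook vectors.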

4. Results

First, the VQ-VAE reconstruction results. To be honest, an ordinary autoencoder reconstructs about as well, so the reconstruction alone is not very informative:
[Figure: VQ-VAE reconstructions on MNIST, originals next to reconstructions]
Then the generation results: these look better than the earlier VAE, PixelCNN and GAN results. Keep in mind this is only after 50 training epochs for the generator; training longer should improve it further.
[Figure: samples drawn with the PixelCNN and decoded by the VQ-VAE]

References

https://github.com/SingleZombie/DL-Demos/blob/master/dldemos/VQVAE/main.py
Many variable names and methods were modified there to stay consistent with the paper and make things easier to follow. Once more: make sure you understand the relationship among ze, z and zq, how they convert into one another, and at which stage each one is used.
