VQGAN理论加代码一对一详解,小白向解析

最近在看图像生成相关论文,记录一下学习内容。感觉只看论文有点干巴,所以理论代码一对一上。

整体网络框架

VQGAN (Vector Quantized Generative Adversarial Network) 是一种基于 GAN 的生成模型,可以将图像或文本转换为高质量的图像。

  • VQ (Vector Quantization)是一种数据压缩技术,是指将连续数据表示为离散化的向量。输入的图像或文本被映射到 VQ 空间中的离散化向量表示,然后,离散化向量然后被送到 GAN 模型中进行图像生成。(参见上图的下半部分)在训练过程中,VQGAN 模型会优化两个损失函数:一个用于量化误差(即离散化向量和连续值之间的误差),另一个用于生成器和判别器之间的对抗损失。
  • GAN 是由生成器和判别器两个模型组成的,生成器负责生成图像,判别器负责判断生成的图像是否为真实的图像。在训练过程中,生成器和判别器相互博弈,不断优化各自的参数,以使生成的图像更接近真实图像。

在这里插入图片描述
上图是论文的总体模型图。下面具体来看看如何实现的。

训练过程

VQGAN整体模型需要两步训练。

  • 第一步通过自监督学习训练CNN Encoder,CNN Decoder,和Codebook;
  • 第二步在已训练好的CNN Encoder和Codebook基础上,通过将code随机替换加入强噪声,用Transformer去重建其code组,来提高Transformer的泛化能力。

第一步——CNN Encoder,CNN Decoder,Codebook

如上图所示,从一张输入图片开始(一般是RGB图片) x ∈ R H × W × 3 x \in \mathbb{R}^{H\times W×3} xRH×W×3,其通过CNN Encoder编码后得到中间特征变量 z ^ ∈ R h × w × n z \hat z \in \mathbb{R}^{h\times w×n_z} z^Rh×w×nz。这时再引入一个codebook,注意,如果是普通的AutoEncoder,则会将 z ^ \hat z z^ 直接送入解码器中进行图像重建。而在VQVAE/VQGAN中,会将 z ^ \hat z z^进行进一步离散化编码成 z q ∈ R h × w × n z z_q\in \mathbb{R}^{h\times w×n_z} zqRh×w×nz

具体做法为:预先生成一个离散数值的codebook Z = { z k } k = 1 K , z k ∈ R n z \mathcal Z=\{z_k\}_{k=1}^{K},z_k \in \mathbb{R}^{n_z} Z={zk}k=1K,zkRnz,在 z ^ \hat z z^ 的每一个编码位置都去 Z \mathcal Z Z中去寻找其距离最近的code,生成具有相同维度的变量。特别注意,这里 z ^ , z q \hat z,z_q z^,zq Z \mathcal Z Z中的单个编码特征的维度都为 n z n_z nz。这一步离散编码的过程就叫做“quantization”, 也就是上面的那个公式。

这样一来,就可以在已经数值离散化的 z q z_q zq基础上使用CNN Decoder进行解码:
x ^ = G ( z q ) = G ( q ( E ( x ) ) ) \hat x=G(z_q)=G(q(E(x))) x^=G(zq)=G(q(E(x)))

整个过程的自监督损失如下:
L V Q ( E , G , Z ) = ∣ ∣ x − x ^ ∣ ∣ 2 + ∣ ∣ s g [ E ( x ) ] − z q ∣ ∣ 2 + ∣ ∣ s g ( z q ) − E ( x ) ∣ ∣ 2 \mathcal L_{VQ}(E,G,Z)=||x-\hat x||^2+||sg[E(x)]-z_q||^2+||sg(z_q)-E(x)||^2 LVQ(E,G,Z)=∣∣xx^2+∣∣sg[E(x)]zq2+∣∣sg(zq)E(x)2其中,上式中的第一项 L r e c \mathcal L_{rec} Lrec 为重建损失(reconstruction loss) s g [ ⋅ ] sg[·] sg[] 为梯度终止操作(stop-gradient operation),其目的在于保证神经网络梯度可以正常回传,而不受离散编码的影响。因此在codebook的搭建过程中,我们看到由 z ^ \hat z z^得到 z q z_q zq之后,先计算出公式中后两项损失,然后又增加了一步detach操作。

loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
z_q = z + (z_q - z).detach()

这么一来,在其后面计算 L r e c \mathcal L_{rec} Lrec,即公式的第一项中, z q z_q zq的梯度可以顺利复制到 z ^ \hat z z^上,而不受离散编码过程的干扰。除了这个重建过程使用的自监督损失外,还加入了GAN中的对抗loss。文章里没有具体写对抗loss的类型。通过源码可以发现使用的是hinge loss。对于判别器而言,其损失函数可以笼统地表示为:
L G A N ( { E , G , Z } , D ) = l o g D ( x ) + l o g ( 1 − D ( x ^ ) ) \mathcal L_{GAN}(\{E,G,\mathcal Z\}, D)=logD(x)+log(1-D(\hat x)) LGAN({E,G,Z},D)=logD(x)+log(1D(x^))

所以总的误差可以写成:
L = L V Q + λ L G A N \mathcal L = \mathcal L_{VQ}+\lambda \mathcal L_{GAN} L=LVQ+λLGAN

总结来说就是:
x → z ^ → z q → x ^ x\to \hat z\to z_q\to \hat x xz^zqx^
下面主要来看看这三部分的代码
CNN Encoder, CNN Decoder是一种基于UNet的代码结构,具体细节可以从原文中获取,这里不在细说

CNN Encoder

class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()
        channels = [128, 128, 128, 256, 256, 512]
        attn_resolutions = [16]
        num_res_blocks = 2
        resolution = 256
        layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]
        for i in range(len(channels)-1):
            in_channels = channels[i]
            out_channels = channels[i + 1]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(NonLocalBlock(in_channels))
            if i != len(channels)-2:
                layers.append(DownSampleBlock(channels[i+1]))
                resolution //= 2
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(NonLocalBlock(channels[-1]))
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(GroupNorm(channels[-1]))
        layers.append(Swish())
        layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1))
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.model(x)

具体的模块定义可以阅读源代码,这个都不难理解。

CNN Decoder

class Decoder(nn.Module):
    def __init__(self, args):
        super(Decoder, self).__init__()
        channels = [512, 256, 256, 128, 128]
        attn_resolutions = [16]
        num_res_blocks = 3
        resolution = 16

        in_channels = channels[0]
        layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1),
                  ResidualBlock(in_channels, in_channels),
                  NonLocalBlock(in_channels),
                  ResidualBlock(in_channels, in_channels)]

        for i in range(len(channels)):
            out_channels = channels[i]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(NonLocalBlock(in_channels))
            if i != 0:
                layers.append(UpSampleBlock(in_channels))
                resolution *= 2
               
        layers.append(GroupNorm(in_channels))
        layers.append(Swish())
        layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

Codebook

我最开始看的时候,最不明白的地方就是这个codebook,一直在想,这兄弟是哪蹦出来的。其实就是另外定义的一个网络,说白了甚至算不上一个网络就是一个nn.Embedding(),还是之前没看VQVAE的锅。

class Codebook(nn.Module):
    def __init__(self, args):
        super(Codebook, self).__init__()
        self.num_codebook_vectors = args.num_codebook_vectors
        self.latent_dim = args.latent_dim
        self.beta = args.beta
        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

    def forward(self, z):
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.latent_dim)
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - \
            2*(torch.matmul(z_flattened, self.embedding.weight.t()))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
        z_q = z + (z_q - z).detach()
        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss

第二步——Transformer 训练

经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。
压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。而恰好,Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。
VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第i 步,Transformer会根据前i−1 个像素 s < i s_{<i} s<i生成第 i i i 个像素 s i s_i si.

在这里插入图片描述

来看具体实现——训练过程

现在进入第二步,这篇论文毕竟是个图像生成的任务,注意之前的三个零件已经训练好不动了,现在我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。那么这组code怎么来的?这就是Transformer发挥作用的地方了。该工作使用的Transformer模型为著名的GPT-2。迁移到VQGAN中,即可理解为先预测一个code,再一步步地通过已经预测好的code去推断下一个code。

code都是从训练好的codebook Z \mathcal Z Z中寻找,就像写文章一样,你有词典了,现在你要从词典中一个字一个字的写成一篇新文章

为了训练Transformer,

  • 将输入图片 x ∈ R H × W × 3 x \in \mathbb{R}^{H\times W×3} xRH×W×3,通过CNN Encoder编码后得到中间特征变量 z ^ ∈ R h × w × n z \hat z \in \mathbb{R}^{h\times w×n_z} z^Rh×w×nz,再将 z ^ \hat z z^进行进一步离散化编码成 z q ∈ R h × w × n z z_q\in \mathbb{R}^{h\times w×n_z} zqRh×w×nz,[注意这部分都是用上一步训练好的模型,这里只做前传,不做梯度回传训练],
  • z q z_q zq 被展平到空间 R h w × n z \mathbb{R}^{hw×n_z} Rhw×nz ,这样得到了 h w hw hw 个排列好的维度为 n z n_z nz的code。
  • 随机将其中的一部分code替换为随机生成的相同维度的向量,输入transformer模型,也即是给特征中加入噪声。接着进行训练,训练损失函数为cross-entropy交叉熵损失。

假设被替换后的code组合的索引为modified_indices,原本 z q z_q zq的code索引为unmodified_indices,那么Transformer的学习过程即为:喂入modified_indices,通过训练学习重构出unmodified_indices。
L t r a n s f o r m e r = E x ∼ p ( x ) [ − l o g p ( s ) ] \mathcal L_{transformer}=\mathbb E_{x\sim p(x)}[-logp(s)] Ltransformer=Exp(x)[logp(s)]

代码具体实现如下:

"""
首先得到由x前传得到的unmodified_indices
"""

sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token 
# (B, 1), sos_token是一个整数,表示从第几个token开始预测,一般为0

mask = torch.bernoulli(self.pkeep * torch.ones(unmodified_indices.shape, device=unmodified_indices.device)) 
# (B, h*w), 元素都为0和1,0的是mask掉的元素,1是保留的元素(比例为pkeep)

mask = mask.round().to(dtype=torch.int64)
random_indices = torch.randint_like(indices, self.transformer.config.vocab_size)    
# (B, h*w), 生成一些任意的indices,用来填充被遮挡的部分
modified_indices= mask * unmodified_indices+ (1 - mask) * random_indices 
# (B, h*w), mask为1(未遮挡)部分仍然保留原始indices,mask为0(遮挡)部分用random_indices填充
modified_indices= torch.cat((sos_tokens, modified_indices), dim=1) 
# (B, h*w+1),将0放到第一个indice前面
targets = unmodified_indices
logits, _ = self.transformer(modified_indices[:, :-1]) 
# logits: (B, h*w, num_codebook_vectors), 意思是h*w个indices处,预测出来的对应每一个codebook_vector的概率


"""
然后再由logits和targets之间计算交叉熵损失
"""
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

注意这是训练的过程,不是生成的过程。在VQGAN无条件生成图片的过程中,没有任何先验条件,CNN Encoder直接被弃用。我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。

  • 12
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
好的,我这里随机给你一段简单的 JavaScript 代码,并进行详解: ```javascript function calculateSum(arr) { let sum = 0; for (let i = 0; i < arr.length; i++) { sum += arr[i]; } return sum; } const numbers = [1, 2, 3, 4, 5]; const result = calculateSum(numbers); console.log(result); ``` 这段代码定义了一个名为 `calculateSum` 的函数,该函数接收一个数组作为参数,并返回该数组中所有元素的和。然后,我们创建了一个数组 `numbers`,其中包含了一些数字。接下来,我们调用 `calculateSum` 函数,并将 `numbers` 数组作为参数传递给它。最后,我们将函数的返回值存储在变量 `result` 中,并将其打印到控制台中。 让我们逐行解释这段代码的详细过程: 1. `function calculateSum(arr) {` 定义了一个名为 `calculateSum` 的函数,并接收一个参数 `arr`,该参数是一个数组。 2. `let sum = 0;` 声明并初始化一个名为 `sum` 的变量,用于保存数组中所有元素的和。初始值为 `0`。 3. `for (let i = 0; i < arr.length; i++) {` 使用 `for` 循环来遍历数组中的每个元素。循环变量 `i` 从 `0` 开始,每次循环增 `1`,直到 `i` 大于等于数组长度 `arr.length`。 4. `sum += arr[i];` 在循环体中,将数组中第 `i` 个元素的值到 `sum` 变量上。 5. `}` 循环结束。 6. `return sum;` 返回 `sum` 变量的值,即数组中所有元素的和。 7. `const numbers = [1, 2, 3, 4, 5];` 声明一个名为 `numbers` 的常量,用于存储一个包含了一些数字的数组。 8. `const result = calculateSum(numbers);` 调用 `calculateSum` 函数,并将 `numbers` 数组作为参数传递给它。将函数的返回值存储在名为 `result` 的常量中。 9. `console.log(result);` 将 `result` 变量的值打印到控制台中。 希望这个例子能帮助你更好地理解 JavaScript 函数和数组的使用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值