文章目录
论文:Online Clustered Codebook
原文地址:Online Clustered Codebook。本文是在阅读原文时的简要总结和记录。
Abstract
1.本文解决的任务
尽可能消除矢量量化(Vector Quantization, VQ)中出现的码本崩溃(codebook collapse)问题:即码本中只有一小部分码向量(codevectors)接收对其优化有用的梯度(称为活(alive)码向量),而其余码向量从未更新或使用(称为死(dead)码向量)。这一问题限制了码本在需要大容量表示的任务上的有效性。
2.本文提出的方法:Clustering VQ-VAE (CVQ-VAE)
选择编码特征作为锚点(achor)来更新“死”码向量,同时通过原始损失优化“活”码向量。
3. 本文得到的结果
- 所提出的CVQ策略使“死”码向量在分布上更接近编码特征,从而增加了被选择和优化的可能性。
- 在各种数据集、任务(例如重建和生成)和架构(例如 VQ-VAE、VQGAN、LDM)上验证了CVQ的泛化能力。 只需几行代码,CVQ-VAE 就可以轻松集成到现有模型中。
4. 本文的代码和模型地址
https://github.com/lyndonzheng/CVQ-VAE
Introduction
1.动机
在VQ的过程中,量化操作阻止梯度反向传播到代码向量,导致码本崩溃(即只有一小部分码向量与可学习的特征一起进行了优化,而码本中大多数码向量根本没有被使用)。 这一问题极大地限制了VQ的有效性,导致码向量利用率低而无法充分利用码本的表达能力,特别是当码本大小很大时。
2.本文贡献:CVQ-VAE
- CVQ的基本做法:仿照经典的聚类算法(例如k-means和k-means++),通过从学习到的特征中重新采样来动态初始化未优化的码本。
- CVQ的结果:避免码本崩溃,并通过优化所有码向量显着提高较大码本的使用率。
- CVQ的具体做法:计算不同minibatch的特征的运行平均值(running average),并使用它们来改进“死”码向量的动态重新初始化。
3. 实验结果
- CVQ-VAE 在相同设置下的各种数据集上显着优于以前的模型VQ-VAE和SQ-VAE。
- 对该方法的变体进行了彻底的消融实验,以证明CVQ设计的有效性并分析各种设计因素的重要性。
- 将 CVQ-VAE 合并到大型模型中(例如 VQ-GAN 和 LDM),进一步证明了其在各种应用中的通用性和潜力。
Related Work
Jukebox、HVQ-VAE、SQ-VAE、VQ-WAE。
Proposed Approach
考虑一张 x ∈ R H × W × c x\in\mathbb{R}^{H\times W\times c} x∈RH×W×c的图片,编码器输出特征为 z ^ = E ϕ ( x ) \hat{z}=\mathcal{E}_{\phi}(x) z^=Eϕ(x),被量化为离散码字 z q ∈ R h × w × n q z_q\in\mathbb{R}^{h\times w\times n_q} zq∈Rh×w×nq, n q n_q nq是码向量的维度,码本大小为 K K K。
1. 运行平均更新
- 首先累计计算每个training minibatch中码向量的平均使用情况:
其中 n k ( t ) n_k^{(t)} nk(t)是training minibatch中将被量化为码向量 e k e_k ek的特征数量, B h w Bhw Bhw表示batch、高度和宽度上的特征数量。 γ ∈ ( 0 , 1 ) \gamma\in(0,1) γ∈(0,1)
(default γ = 0.99 \gamma=0.99 γ=0.99)是一个衰减超参数, N k ( 0 ) = 0 N_k^{(0)}=0 Nk(0)=0 - 然后,使用选定的锚点更新码向量:从特征
z
^
\hat{z}
z^中选择具有
K
K
K个向量的子集
Z
ˉ
\bar{\mathcal{Z}}
Zˉ作为锚点,由于期望“死”码向量应该比“活”码向量进行更多的修改而不是直接使用锚点来重新初始化“死”码向量,因此使用累积平均使用
N
k
(
t
)
N_k^{(t)}
Nk(t)计算每个码向量
e
k
e_k
ek的衰减值
a
k
(
t
)
a_k^{(t)}
ak(t)并重新初始化特征如下
其中 ϵ \epsilon ϵ是一个小常数以确保为码向量分配不同minibatch的特征平均值, z ^ k ( t ) \hat{z}_k^{(t)} z^k(t)是采样的锚点。
这一步用于更新“死”码向量,且 a k ( t ) a_k^{(t)} ak(t)是基于平均使用计算得到的而不是预先定义的超参,这与VQ-VAE中使用的EMA不同。
2. 锚点的选择
锚点的几种选择方法:
- 随机:从特征中随机采样作为锚点
- 独特:为避免重复的锚点,对特征数量 B h w Bhw Bhw内的整数进行随机排列,然后选择前 K 个特征
- 最近:反向查找每个码向量最接近的特征,即 i = a r g m i n z ^ i ∈ E ϕ ( x ) = ∣ ∣ z ^ i − e k ∣ ∣ i=argmin_{\hat{z}_i\in \mathcal{E}_{\phi}(x)}=||\hat{z}_i-e_k|| i=argminz^i∈Eϕ(x)=∣∣z^i−ek∣∣
- 概率随机:基于码向量和特征之间的距离 D i , k D_{i,k} Di,k,设置概率为 p = exp ( − D i , k ) ∑ i = 1 B h w exp ( − D i , k ) p=\frac{\exp(-D_{i,k})}{\sum_{i=1}^{Bhw}\exp(-D_{i,k})} p=∑i=1Bhwexp(−Di,k)exp(−Di,k)
有趣的是,实验结果表明在线(online)版本对锚点的选择方法并不敏感,而不同的锚点采样方法对离线(offline)版本有直接影响。这表明本文提出的运行平均更新行为是改进的主要原因。
3. 对比损失
为了鼓励码本的稀疏性引入对比损失如下:
对于码向量
e
k
e_k
ek,选择正对(positive pair)为与之最近(基于
D
i
,
k
D_{i,k}
Di,k)的特征
z
^
i
+
\hat{z}_i^{+}
z^i+,采样其他特征作为负对(negative pair)。
Experiment: Image Quantisation
1. 实验设置
- Backbone: VQ-VAE(小数据集)、VQ-GAN(大数据集)
- 数据集:MNIST、CIFAR10、Fashion MNIST、FFHQ、ImageNet
- 指标:SSIM、LPIPS、FID、码本perplexity score( e − ∑ k = 1 K p e k log p e k , p e k = n k ∑ i = 1 K n k e^{-\sum_{k=1}^Kp_{e_k}\log p_{e_k}}, p_{e_k}=\frac{n_k}{\sum_{i=1}^Kn_k} e−∑k=1Kpeklogpek,pek=∑i=1Knknk, n k n_k nk是与码向量 e k e_k ek关联的特征数量)
2. 主要结果
- 不同量化方法在VQ-VAE框架下的对比
- 与图像重建任务的SOTA方法对比
- 不同的码本重置方法
3.消融实验
offline: 使用所选锚点重新初始化“死”码向量,但仅限于第一个training batch
online: 运行平均更新
Experiment: Applications - Image Generation
1.实验设置
采用latent diffusion model(LDM)作为backbone,将其中的VQGAN量化器替换为本文描述的CVQ量化器
2. 实验结果