文章目录
一、VQ的简单描述
VQ的目标是:将连续变量序列 z z z映射到离散变量序列 z ^ \hat{z} z^。
- VQ包含一个可以训练的码本 E = { e 1 , e 2 , … , e K } E=\{e_1,e_2,\dots,e_K\} E={e1,e2,…,eK},这个码本包含 K K K个不同的码字;
- VQ的输入是一个连续向量序列
z
=
<
z
1
,
z
2
,
…
,
z
T
>
z=<z_1,z_2,\dots,z_T>
z=<z1,z2,…,zT>, 每个
z
i
z_i
zi会被映射到码本中的一个码字,映射规则为:
k i = arg min j ∥ z i − e j ∥ 2 k_i = \mathop{\arg\min}_{j} \| z_{i}-e_j\|_2 ki=argminj∥zi−ej∥2
根据此规则, z i z_i zi会被映射为 e k i e_{k_i} eki; - VQ的输出为离散的量化序列 z ^ = < z 1 ^ , z 2 ^ , … , z T ^ > = < e k 1 , e k 2 , … , e k T > \hat{z}=<\hat{z_1},\hat{z_2},\dots,\hat{z_T}>=<e_{k_1},e_{k_2},\dots,e_{k_T}> z^=<z1^,z2^,…,zT^>=<ek1,ek2,…,ekT>
二、VQ的训练细节
1. 怎么训练 arg min \mathop{\arg\min} argmin
- arg min \mathop{\arg\min} argmin运算不可微, 因此无法直接计算梯度;但可以用straight-through estimator来近似估计梯度
- straight-through estimator
straight-through estimator的思想很简单,就是前向传播的时候使用想要的变量(哪怕不可微),而反向传播求梯度的时候,为不可微的运算使用所设计的梯度。
在VQ的训练过程中 z i z_i zi由于 arg min \mathop{\arg\min} argmin运算不可微,因此在反向传播时,将其梯度设计为:把VQ的输出(即可能是decoder的输入) z i ^ \hat{z_i} zi^的梯度直接复制给 z i z_i zi;可以理解为直接对 z i z_i zi求梯度,而不是先对 z i ^ \hat{z_i} zi^求梯度,再对 z i z_i zi求梯度。 - 由此,整个网络的重构损失本应为 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ^ ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(\hat{z_i})\|^2 LRecon=T1∑i=1T∥xi−decoder(zi^)∥2,由于将对 z i ^ \hat{z_i} zi^的梯度设计为直接对 z i z_i zi求梯度,因此在反向传播时,重构损失等价于 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i)\|^2 LRecon=T1∑i=1T∥xi−decoder(zi)∥2
2. 怎么为 z i z_i zi找到最接近的码字 e k i e_{k_i} eki,同时更新码本中的码字?
- VQ损失函数可以表示如下,它有两个作用:促使每个
z
i
z_i
zi选择距离最近的码字,同时自动更新码本中的码字
z
^
\hat{z}
z^。
L V Q = 1 T ∑ i = 1 T ∥ z i − z i ^ ∥ 2 = 反 向 传 播 时 1 T ∑ i = 1 T [ ∥ z i − s g ( z i ^ ) ∥ 2 + ∥ s g ( z i ) − z i ^ ∥ 2 ] L_{VQ}=\frac{1}{T}\sum_{i=1}^T\|z_i-\hat{z_i}\|^2\overset{反向传播时}{=}\frac{1}{T}\sum_{i=1}^T[\|z_i-\mathop{sg}(\hat{z_i})\|^2+\|\mathop{sg}(z_i)-\hat{z_i}\|^2] LVQ=T1i=1∑T∥zi−zi^∥2=反向传播时T1i=1∑T[∥zi−sg(zi^)∥2+∥sg(zi)−zi^∥2]
其中 s g ( ⋅ ) sg(·) sg(⋅)表示stop gradient operator - stop gradient operator
stop gradient operator指“不要它的梯度”,即在前向计算时对 s g ( ⋅ ) sg(·) sg(⋅)中的该项正常计算,而在反向传播时不对此项求梯度。
前述的重构损失也可以运用 s g ( ⋅ ) sg(·) sg(⋅)重写作:
L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i + s g ( z i ^ − z i ) ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i+sg(\hat{z_i}-z_i))\|^2 LRecon=T1i=1∑T∥xi−decoder(zi+sg(zi^−zi))∥2
这样在前向计算loss时等价于 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ^ ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(\hat{z_i})\|^2 LRecon=T1∑i=1T∥xi−decoder(zi^)∥2,而在反向传播求梯度时等价于 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i)\|^2 LRecon=T1∑i=1T∥xi−decoder(zi)∥2。 - commitment loss/cost的解释
s g ( ⋅ ) sg(·) sg(⋅)表示在反向传播求梯度时不对此项求梯度。因此 L V Q L_{VQ} LVQ在反向传播时可以等价的拆分为两部分:
(1). ∥ z i − s g ( z i ^ ) ∥ 2 \|z_i-\mathop{sg}(\hat{z_i})\|^2 ∥zi−sg(zi^)∥2因为仅对 z i z_i zi求梯度,目的是“让 z i z_i zi靠近 z i ^ \hat{z_i} zi^”,即为 z i z_i zi选择距离最近的码字 z i ^ = e k i \hat{z_i}=e_{k_i} zi^=eki;这一项称为commitment loss,表示“ z z z commit to z ^ \hat{z} z^的程度”。
(2). ∥ s g ( z i ) − z i ^ ∥ 2 \|\mathop{sg}(z_i)-\hat{z_i}\|^2 ∥sg(zi)−zi^∥2因为仅对 z i ^ \hat{z_i} zi^求梯度,目的是“让 z i ^ \hat{z_i} zi^靠近 z i z_i zi”,即为更新码本中的码字 e k i = z i ^ e_{k_i}=\hat{z_i} eki=zi^。
因为在重构损失中, z z z要尽力保证重构效果,而 z ^ \hat{z} z^相对比较自由;因此在VQ损失中希望"让 z i ^ \hat{z_i} zi^靠近 z i z_i zi"多于" z i z_i zi靠近 z i ^ \hat{z_i} zi^",这一目标可以通过一个系数 α \alpha α ( α < 1 ) (\alpha<1) (α<1)调整两部分的比例来实现,这个 α \alpha α即称作commitment cost:
L V Q = 1 T ∑ i = 1 T [ ∥ s g ( z i ) − z i ^ ∥ 2 + α ∥ z i − s g ( z i ^ ) ∥ 2 ] L_{VQ}=\frac{1}{T}\sum_{i=1}^T[\|\mathop{sg}(z_i)-\hat{z_i}\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2] LVQ=T1i=1∑T[∥sg(zi)−zi^∥2+α∥zi−sg(zi^)∥2]
3. 整个网络的损失
L
=
L
R
e
c
o
n
+
β
L
V
Q
=
1
T
∑
i
=
1
T
∥
x
i
−
d
e
c
o
d
e
r
(
z
i
^
)
∥
2
+
β
(
1
T
∑
i
=
1
T
[
∥
s
g
(
z
i
)
−
z
i
^
∥
2
+
α
∥
z
i
−
s
g
(
z
i
^
)
∥
2
]
)
L=L_{Recon}+\beta L_{VQ}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(\hat{z_i})\|^2+\beta (\frac{1}{T}\sum_{i=1}^T[\|\mathop{sg}(z_i)-\hat{z_i}\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2])
L=LRecon+βLVQ=T1i=1∑T∥xi−decoder(zi^)∥2+β(T1i=1∑T[∥sg(zi)−zi^∥2+α∥zi−sg(zi^)∥2])
如前所述,为便于
arg
min
\mathop{\arg\min}
argmin的梯度计算,可以利用stop-gradient operator将整体损失写作:
L
=
1
T
∑
i
=
1
T
∥
x
i
−
d
e
c
o
d
e
r
(
z
i
+
s
g
(
z
i
^
−
z
i
)
)
∥
2
+
β
(
1
T
∑
i
=
1
T
[
∥
s
g
(
z
i
)
−
z
i
^
∥
2
+
α
∥
z
i
−
s
g
(
z
i
^
)
∥
2
]
)
L=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i+sg(\hat{z_i}-z_i))\|^2+\beta (\frac{1}{T}\sum_{i=1}^T[\|\mathop{sg}(z_i)-\hat{z_i}\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2])
L=T1i=1∑T∥xi−decoder(zi+sg(zi^−zi))∥2+β(T1i=1∑T[∥sg(zi)−zi^∥2+α∥zi−sg(zi^)∥2])
通过这个损失,可以完成重构、为
z
i
z_i
zi找到最接近的码字
e
k
i
e_{k_i}
eki、更新码本中的码字这三个目标;这三个目标分别对应于损失函数中的三项。
4. 更新码本中码字的其他方法 指数移动平均值(EMA)
- 除了使用前述的损失函数 1 T ∑ i = 1 T ∥ s g ( z i ) − z i ^ ∥ 2 \frac{1}{T}\sum_{i=1}^T\|\mathop{sg}(z_i)-\hat{z_i}\|^2 T1∑i=1T∥sg(zi)−zi^∥2更新码字以外,还可以使用指数移动平均值(exponential moving average, EMA)更新码本
- 指数移动平均值(exponential moving average, EMA)
(1). 假设与码字 e i e_i ei距离最接近的 n i n_i ni个连续变量为 { z i , 1 , z i , 2 , … , z i , n i } \{z_{i,1},z_{i,2},\dots,z_{i,n_i}\} {zi,1,zi,2,…,zi,ni},则损失函数还可以写作 min e i ∑ j = 1 n i ∥ z i , j − e i ∥ 2 \mathop{\min}_{e_i}\sum_{j=1}^{n_i}\|z_{i,j}-e_i\|^2 mineij=1∑ni∥zi,j−ei∥2上述优化问题的最优解 e i e_i ei有闭式解,即集合中元素的平均值 e i = 1 n i ∑ j = 1 n i z i , j e_i=\frac{1}{n_i}\sum_{j=1}^{n_i}z_{i,j} ei=ni1∑j=1nizi,j。这种更新方法类似于KMeans中类别中心的更新方法。
(2).但是在实际训练中,由于使用了minibatches方法,无法直接在一个batch中获得与 e i e_i ei距离最接近的所有 n i n_i ni个连续变量,因此不能直接使用上述求均值的方法。但是可以使用EMA作为代替,具体算法是:
N i ( t ) = N i ( t − 1 ) ∗ γ + n i ( t ) ( 1 − γ ) N_i^{(t)}=N_i^{(t-1)}*\gamma+n_i^{(t)}(1-\gamma) Ni(t)=Ni(t−1)∗γ+ni(t)(1−γ) m i ( t ) = m i ( t − 1 ) ∗ γ + ∑ j z i , j ( t ) ( 1 − γ ) m_i^{(t)}=m_i^{(t-1)}*\gamma+\sum_jz_{i,j}^{(t)}(1-\gamma) mi(t)=mi(t−1)∗γ+j∑zi,j(t)(1−γ) e i ( t ) = m i ( t ) N i ( t ) e_i^{(t)}=\frac{m_i^{(t)}}{N_i^{(t)}} ei(t)=Ni(t)mi(t)
其中 0 < γ < 1 0<\gamma<1 0<γ<1; N i N_i Ni是指所有batches中码字 e i e_i ei对应的 { z i , 1 , z i , 2 , … , z i , n i } \{z_{i,1},z_{i,2},\dots,z_{i,n_i}\} {zi,1,zi,2,…,zi,ni}的元素数量, n i ( t ) n_i(t) ni(t)表示当前batch中与码字 e i e_i ei对应的 { z i , 1 , z i , 2 , … , z i , n i } \{z_{i,1},z_{i,2},\dots,z_{i,n_i}\} {zi,1,zi,2,…,zi,ni}的元素数量; m i m_i mi 表示所有batches中与码字 e i e_i ei对应的元素 z i , 1 , z i , 2 , … , z i , n i {z_{i,1},z_{i,2},\dots,z_{i,n_i}} zi,1,zi,2,…,zi,ni的求和值, ∑ j z i , j ( t ) \sum_jz_{i,j}^{(t)} ∑jzi,j(t)表示当前batch中与码字 e i e_i ei对应的元素 z i , 1 , z i , 2 , … , z i , n i {z_{i,1},z_{i,2},\dots,z_{i,n_i}} zi,1,zi,2,…,zi,ni的求和值。 - 在使用EMA更新码本的情况下,整个网络的损失可以写作: L = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i + s g ( z i ^ − z i ) ) ∥ 2 + α ∥ z i − s g ( z i ^ ) ∥ 2 ] L=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i+sg(\hat{z_i}-z_i))\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2] L=T1i=1∑T∥xi−decoder(zi+sg(zi^−zi))∥2+α∥zi−sg(zi^)∥2] 通过这个损失,可以完成重构、为 z i z_i zi找到最接近的码字 e k i e_{k_i} eki这两个目标;这两个目标分别对应于损失函数中的两项。而更新码本的目标由EMA完成。
三、代码解读
1.使用EMA作为更新码本的方法时,VQ部分的代码
代码参考链接:VQ-EMA
class VQEmbeddingEMA(nn.Module):
def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
"""
n_embeddings: 码本大小,即码字总数
embedding_dim: 每个码字的维度
commitment_cost: commitment loss前的系数,即commitment cost
decay: EMA更新公式中的\gamma
epsilon: 防止除数为0
"""
super(VQEmbeddingEMA, self).__init__()
self.commitment_cost = commitment_cost
self.decay = decay
self.epsilon = epsilon
### 初始化码本 ###
init_bound = 1 / n_embeddings
embedding = torch.Tensor(n_embeddings, embedding_dim)
# 从均匀分布U[-1/n_embeddings,1/n_embeddings]中抽样数值对tensor进行填充
embedding.uniform_(-init_bound, init_bound)
### 设置一些参数 ###
# self.register_buffer('name', Tensor)定义一组名为name的参数,该组参数的特别之处在于:
# 调用optimizer.step()后该组参数不会变化,只可人为地改变它们的值;
# 但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
self.register_buffer("embedding", embedding)
self.register_buffer("ema_count", torch.zeros(n_embeddings))
self.register_buffer("ema_weight", self.embedding.clone())
def forward(self, x):
"""
输入:
x: [B, T, D],为连续向量序列
----------------------------------
输出:
quantized: 离散化后的序列,即\hat{z}
loss: VQ Loss
"""
K, D = self.embedding.size() # K表示码字总数/码本大小,D表示码字维度
x_flat = x.detach().reshape(-1, D) # x:[B,T,D]->x_flat:[BxT,D]
# torch.addmm(M,M1,M2,a,b) = bM+a(M1@M2), 其中M1@M2表示矩阵乘法
# 计算序列x和码本中各码字之间的距离
distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
torch.sum(x_flat ** 2, dim=1, keepdim=True),
x_flat, self.embedding.t(),
alpha=-2.0, beta=1.0)
# 选择距离最近的码字,获得的indices为相应码字的索引序列
indices = torch.argmin(distances.float(), dim=-1)
# F.one_hot(indices, K)是对indices进行one-hot编码
# 例,F.one_hot([5,3,2,4,1], 6)将[5,3,2,4,1]编码为
# [[0,0,0,0,0,1]
# [0,0,0,1,0,0]
# [0,0,1,0,0,0]
# [0,0,0,0,1,0]
# [0,1,0,0,0,0]]
encodings = F.one_hot(indices, K).float() # encodings为索引序列indices的one-hot编码
### 获得相应的码字 ###
# F.embedding(indices, self.embedding)用于使用索引indices在固定码本self.embedding中检索码字
quantized = F.embedding(indices, self.embedding) # quantized为检索到的相应的码字
quantized = quantized.view_as(x) # [BxT,D]->[B,T,D]
### 使用EMA方法更新码本 ###
if self.training:
# self.ema_count即为EMA更新公式中的N(t),其中的第i个元素表示所有数据中与第i个码字对应的连续变量x_i的数量
# torch.sum(encodings, dim=0)即EMA更新公式中的n(t),其中的第i个元素表示当前batch中与第i个码字对应的连续变量x_i的数量
self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
n = torch.sum(self.ema_count)
self.ema_count = (self.ema_count + self.epsilon) / (n + D * self.epsilon) * n
# dw即EMA更新公式中的\sum{z_{i,j}},第i个元素即当前batch中与第i个码字对应的连续元素的和
dw = torch.matmul(encodings.t(), x_flat)
#self.ema_weight即EMA更新公式中的m(t),第i个元素即所有batch中与第i个码字对应的连续变量的和
self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
# 更新码本中的码字
self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
### 计算VQ Loss ###
# VQ损失,固定quantized(因为这一项已通过上述EMA方法更新),使x向quantized更靠近
e_latent_loss = F.mse_loss(x, quantized.detach())
loss = self.commitment_cost * e_latent_loss
### 使用stop-gradient operator, 便于反向传播计算梯度###
# .detach()即使用了stop-gradient operator,在反向传播的时候只计算对x的梯度
quantized = x + (quantized - x).detach()
return quantized, loss
2. 使用MSE损失更新码本时,VQ部分的代码
代码参考链接:VQ-MSE,为与上部分统一,有修改
class VQEmbedding(nn.Module):
def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25):
"""
n_embeddings: 码本大小,即码字总数
embedding_dim: 每个码字的维度
commitment_cost: 在总体损失函数中,commitment loss前的系数,即commitment cost
"""
super(VQEmbedding, self).__init__()
self.commitment_cost = commitment_cost
### 初始化码本 ###
init_bound = 1 / n_embeddings
self.embedding = nn.Embedding(n_embeddings, embedding_dim)
# 从均匀分布U[-1/512,1/512]中抽样数值对tensor进行填充
self.embedding.weight.data.uniform_(-init_bound, init_bound)
def forward(self, x):
"""
输入:
x: [B, T, D],为连续向量序列
----------------------------------
输出:
quantized: 离散化后的序列,即\hat{z}
loss: VQ Loss
"""
K, D = self.embedding.size() # K表示码字总数/码本大小,D表示码字维度
x_flat = x.detach().reshape(-1, D) # x:[B,T,D]->x_flat:[BxT,D]
# torch.addmm(M,M1,M2,a,b) = bM+a(M1@M2), 其中M1@M2表示矩阵乘法
# 计算序列x和码本中各码字之间的距离
distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
torch.sum(x_flat ** 2, dim=1, keepdim=True),
x_flat, self.embedding.t(),
alpha=-2.0, beta=1.0)
# 选择距离最近的码字,获得的indices为相应码字的索引序列
indices = torch.argmin(distances.float(), dim=-1)
### 获得相应的码字 ###
quantized = self.embedding(indices) # quantized为检索到的相应的码字
quantized = quantized.view_as(x) # [BxT,D]->[B,T,D]
if not self.training:
return quantized
### 计算VQ Loss ###
# VQ损失,固定x,使quantized向x更靠近
q_latent_loss = F.mse_loss(quantized, x.detach())
# VQ损失,固定quantized,使x向quantized更靠近,即commitment loss
e_latent_loss = F.mse_loss(x, quantized.detach())
# 整体VQ损失
loss = q_latent_loss + self.commitment_cost*e_latent_loss
### 使用stop-gradient operator, 便于反向传播计算梯度###
# .detach()即使用了stop-gradient operator,在反向传播的时候只计算对x的梯度
quantized = x + (quantized - x).detach()
return quantized, loss```