快速了解矢量量化Vector-Quantized(VQ)及相应代码


一、VQ的简单描述

VQ的目标是:将连续变量序列 z z z映射到离散变量序列 z ^ \hat{z} z^

  1. VQ包含一个可以训练的码本 E = { e 1 , e 2 , … , e K } E=\{e_1,e_2,\dots,e_K\} E={e1,e2,,eK},这个码本包含 K K K个不同的码字;
  2. VQ的输入是一个连续向量序列 z = < z 1 , z 2 , … , z T > z=<z_1,z_2,\dots,z_T> z=<z1,z2,,zT>, 每个 z i z_i zi会被映射到码本中的一个码字,映射规则为:
    k i = arg ⁡ min ⁡ j ∥ z i − e j ∥ 2 k_i = \mathop{\arg\min}_{j} \| z_{i}-e_j\|_2 ki=argminjziej2
    根据此规则, z i z_i zi会被映射为 e k i e_{k_i} eki
  3. VQ的输出为离散的量化序列 z ^ = < z 1 ^ , z 2 ^ , … , z T ^ > = < e k 1 , e k 2 , … , e k T > \hat{z}=<\hat{z_1},\hat{z_2},\dots,\hat{z_T}>=<e_{k_1},e_{k_2},\dots,e_{k_T}> z^=<z1^,z2^,,zT^>=<ek1,ek2,,ekT>

二、VQ的训练细节

1. 怎么训练 arg ⁡ min ⁡ \mathop{\arg\min} argmin

  1. arg ⁡ min ⁡ \mathop{\arg\min} argmin运算不可微, 因此无法直接计算梯度;但可以用straight-through estimator来近似估计梯度
  2. straight-through estimator
    straight-through estimator的思想很简单,就是前向传播的时候使用想要的变量(哪怕不可微),而反向传播求梯度的时候,为不可微的运算使用所设计的梯度。
    在VQ的训练过程中 z i z_i zi由于 arg ⁡ min ⁡ \mathop{\arg\min} argmin运算不可微,因此在反向传播时,将其梯度设计为:把VQ的输出(即可能是decoder的输入) z i ^ \hat{z_i} zi^的梯度直接复制给 z i z_i zi;可以理解为直接对 z i z_i zi求梯度,而不是先对 z i ^ \hat{z_i} zi^求梯度,再对 z i z_i zi求梯度。
  3. 由此,整个网络的重构损失本应为 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ^ ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(\hat{z_i})\|^2 LRecon=T1i=1Txidecoder(zi^)2,由于将对 z i ^ \hat{z_i} zi^的梯度设计为直接对 z i z_i zi求梯度,因此在反向传播时,重构损失等价于 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i)\|^2 LRecon=T1i=1Txidecoder(zi)2

2. 怎么为 z i z_i zi找到最接近的码字 e k i e_{k_i} eki,同时更新码本中的码字?

  1. VQ损失函数可以表示如下,它有两个作用:促使每个 z i z_i zi选择距离最近的码字,同时自动更新码本中的码字 z ^ \hat{z} z^
    L V Q = 1 T ∑ i = 1 T ∥ z i − z i ^ ∥ 2 = 反 向 传 播 时 1 T ∑ i = 1 T [ ∥ z i − s g ( z i ^ ) ∥ 2 + ∥ s g ( z i ) − z i ^ ∥ 2 ] L_{VQ}=\frac{1}{T}\sum_{i=1}^T\|z_i-\hat{z_i}\|^2\overset{反向传播时}{=}\frac{1}{T}\sum_{i=1}^T[\|z_i-\mathop{sg}(\hat{z_i})\|^2+\|\mathop{sg}(z_i)-\hat{z_i}\|^2] LVQ=T1i=1Tzizi^2=T1i=1T[zisg(zi^)2+sg(zi)zi^2]
    其中 s g ( ⋅ ) sg(·) sg()表示stop gradient operator
  2. stop gradient operator
    stop gradient operator指“不要它的梯度”,即在前向计算时对 s g ( ⋅ ) sg(·) sg()中的该项正常计算,而在反向传播时不对此项求梯度。
    前述的重构损失也可以运用 s g ( ⋅ ) sg(·) sg()重写作:
    L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i + s g ( z i ^ − z i ) ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i+sg(\hat{z_i}-z_i))\|^2 LRecon=T1i=1Txidecoder(zi+sg(zi^zi))2
    这样在前向计算loss时等价于 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ^ ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(\hat{z_i})\|^2 LRecon=T1i=1Txidecoder(zi^)2,而在反向传播求梯度时等价于 L R e c o n = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ) ∥ 2 L_{Recon}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i)\|^2 LRecon=T1i=1Txidecoder(zi)2
  3. commitment loss/cost的解释
    s g ( ⋅ ) sg(·) sg()表示在反向传播求梯度时不对此项求梯度。因此 L V Q L_{VQ} LVQ在反向传播时可以等价的拆分为两部分:
    (1). ∥ z i − s g ( z i ^ ) ∥ 2 \|z_i-\mathop{sg}(\hat{z_i})\|^2 zisg(zi^)2因为仅对 z i z_i zi求梯度,目的是“让 z i z_i zi靠近 z i ^ \hat{z_i} zi^”,即为 z i z_i zi选择距离最近的码字 z i ^ = e k i \hat{z_i}=e_{k_i} zi^=eki;这一项称为commitment loss,表示“ z z z commit to z ^ \hat{z} z^的程度”。
    (2). ∥ s g ( z i ) − z i ^ ∥ 2 \|\mathop{sg}(z_i)-\hat{z_i}\|^2 sg(zi)zi^2因为仅对 z i ^ \hat{z_i} zi^求梯度,目的是“让 z i ^ \hat{z_i} zi^靠近 z i z_i zi”,即为更新码本中的码字 e k i = z i ^ e_{k_i}=\hat{z_i} eki=zi^
    因为在重构损失中, z z z要尽力保证重构效果,而 z ^ \hat{z} z^相对比较自由;因此在VQ损失中希望"让 z i ^ \hat{z_i} zi^靠近 z i z_i zi"多于" z i z_i zi靠近 z i ^ \hat{z_i} zi^",这一目标可以通过一个系数 α \alpha α ( α < 1 ) (\alpha<1) (α<1)调整两部分的比例来实现,这个 α \alpha α即称作commitment cost
    L V Q = 1 T ∑ i = 1 T [ ∥ s g ( z i ) − z i ^ ∥ 2 + α ∥ z i − s g ( z i ^ ) ∥ 2 ] L_{VQ}=\frac{1}{T}\sum_{i=1}^T[\|\mathop{sg}(z_i)-\hat{z_i}\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2] LVQ=T1i=1T[sg(zi)zi^2+αzisg(zi^)2]

3. 整个网络的损失

L = L R e c o n + β L V Q = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i ^ ) ∥ 2 + β ( 1 T ∑ i = 1 T [ ∥ s g ( z i ) − z i ^ ∥ 2 + α ∥ z i − s g ( z i ^ ) ∥ 2 ] ) L=L_{Recon}+\beta L_{VQ}=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(\hat{z_i})\|^2+\beta (\frac{1}{T}\sum_{i=1}^T[\|\mathop{sg}(z_i)-\hat{z_i}\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2]) L=LRecon+βLVQ=T1i=1Txidecoder(zi^)2+β(T1i=1T[sg(zi)zi^2+αzisg(zi^)2])
如前所述,为便于 arg ⁡ min ⁡ \mathop{\arg\min} argmin的梯度计算,可以利用stop-gradient operator将整体损失写作:
L = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i + s g ( z i ^ − z i ) ) ∥ 2 + β ( 1 T ∑ i = 1 T [ ∥ s g ( z i ) − z i ^ ∥ 2 + α ∥ z i − s g ( z i ^ ) ∥ 2 ] ) L=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i+sg(\hat{z_i}-z_i))\|^2+\beta (\frac{1}{T}\sum_{i=1}^T[\|\mathop{sg}(z_i)-\hat{z_i}\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2]) L=T1i=1Txidecoder(zi+sg(zi^zi))2+β(T1i=1T[sg(zi)zi^2+αzisg(zi^)2])
通过这个损失,可以完成重构、为 z i z_i zi找到最接近的码字 e k i e_{k_i} eki、更新码本中的码字这三个目标;这三个目标分别对应于损失函数中的三项。

4. 更新码本中码字的其他方法 指数移动平均值(EMA)

  1. 除了使用前述的损失函数 1 T ∑ i = 1 T ∥ s g ( z i ) − z i ^ ∥ 2 \frac{1}{T}\sum_{i=1}^T\|\mathop{sg}(z_i)-\hat{z_i}\|^2 T1i=1Tsg(zi)zi^2更新码字以外,还可以使用指数移动平均值(exponential moving average, EMA)更新码本
  2. 指数移动平均值(exponential moving average, EMA)
    (1). 假设与码字 e i e_i ei距离最接近的 n i n_i ni个连续变量为 { z i , 1 , z i , 2 , … , z i , n i } \{z_{i,1},z_{i,2},\dots,z_{i,n_i}\} {zi,1,zi,2,,zi,ni},则损失函数还可以写作 min ⁡ e i ∑ j = 1 n i ∥ z i , j − e i ∥ 2 \mathop{\min}_{e_i}\sum_{j=1}^{n_i}\|z_{i,j}-e_i\|^2 mineij=1nizi,jei2上述优化问题的最优解 e i e_i ei有闭式解,即集合中元素的平均值 e i = 1 n i ∑ j = 1 n i z i , j e_i=\frac{1}{n_i}\sum_{j=1}^{n_i}z_{i,j} ei=ni1j=1nizi,j。这种更新方法类似于KMeans中类别中心的更新方法。
    (2).但是在实际训练中,由于使用了minibatches方法,无法直接在一个batch中获得与 e i e_i ei距离最接近的所有 n i n_i ni个连续变量,因此不能直接使用上述求均值的方法。但是可以使用EMA作为代替,具体算法是:
    N i ( t ) = N i ( t − 1 ) ∗ γ + n i ( t ) ( 1 − γ ) N_i^{(t)}=N_i^{(t-1)}*\gamma+n_i^{(t)}(1-\gamma) Ni(t)=Ni(t1)γ+ni(t)(1γ) m i ( t ) = m i ( t − 1 ) ∗ γ + ∑ j z i , j ( t ) ( 1 − γ ) m_i^{(t)}=m_i^{(t-1)}*\gamma+\sum_jz_{i,j}^{(t)}(1-\gamma) mi(t)=mi(t1)γ+jzi,j(t)(1γ) e i ( t ) = m i ( t ) N i ( t ) e_i^{(t)}=\frac{m_i^{(t)}}{N_i^{(t)}} ei(t)=Ni(t)mi(t)
    其中 0 < γ < 1 0<\gamma<1 0<γ<1 N i N_i Ni是指所有batches中码字 e i e_i ei对应的 { z i , 1 , z i , 2 , … , z i , n i } \{z_{i,1},z_{i,2},\dots,z_{i,n_i}\} {zi,1,zi,2,,zi,ni}的元素数量, n i ( t ) n_i(t) ni(t)表示当前batch中与码字 e i e_i ei对应的 { z i , 1 , z i , 2 , … , z i , n i } \{z_{i,1},z_{i,2},\dots,z_{i,n_i}\} {zi,1,zi,2,,zi,ni}的元素数量; m i m_i mi 表示所有batches中与码字 e i e_i ei对应的元素 z i , 1 , z i , 2 , … , z i , n i {z_{i,1},z_{i,2},\dots,z_{i,n_i}} zi,1,zi,2,,zi,ni的求和值, ∑ j z i , j ( t ) \sum_jz_{i,j}^{(t)} jzi,j(t)表示当前batch中与码字 e i e_i ei对应的元素 z i , 1 , z i , 2 , … , z i , n i {z_{i,1},z_{i,2},\dots,z_{i,n_i}} zi,1,zi,2,,zi,ni的求和值。
  3. 在使用EMA更新码本的情况下,整个网络的损失可以写作: L = 1 T ∑ i = 1 T ∥ x i − d e c o d e r ( z i + s g ( z i ^ − z i ) ) ∥ 2 + α ∥ z i − s g ( z i ^ ) ∥ 2 ] L=\frac{1}{T}\sum_{i=1}^T\|x_i-decoder(z_i+sg(\hat{z_i}-z_i))\|^2+\alpha\|z_i-\mathop{sg}(\hat{z_i})\|^2] L=T1i=1Txidecoder(zi+sg(zi^zi))2+αzisg(zi^)2] 通过这个损失,可以完成重构、为 z i z_i zi找到最接近的码字 e k i e_{k_i} eki这两个目标;这两个目标分别对应于损失函数中的两项。而更新码本的目标由EMA完成。

三、代码解读

1.使用EMA作为更新码本的方法时,VQ部分的代码

代码参考链接:VQ-EMA

class VQEmbeddingEMA(nn.Module):
    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
        """
	    n_embeddings: 码本大小,即码字总数
	    embedding_dim: 每个码字的维度
	    commitment_cost: commitment loss前的系数,即commitment cost
	    decay: EMA更新公式中的\gamma
	    epsilon: 防止除数为0
	    """
        super(VQEmbeddingEMA, self).__init__()
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon
		
		### 初始化码本 ###
        init_bound = 1 / n_embeddings
        embedding = torch.Tensor(n_embeddings, embedding_dim)
        # 从均匀分布U[-1/n_embeddings,1/n_embeddings]中抽样数值对tensor进行填充 
        embedding.uniform_(-init_bound, init_bound) 
        
        ### 设置一些参数 ###
        # self.register_buffer('name', Tensor)定义一组名为name的参数,该组参数的特别之处在于:
        # 调用optimizer.step()后该组参数不会变化,只可人为地改变它们的值;
        # 但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.zeros(n_embeddings))
        self.register_buffer("ema_weight", self.embedding.clone())
    
    def forward(self, x):
        """
		输入:
		x: [B, T, D],为连续向量序列
		----------------------------------
		输出:
		quantized: 离散化后的序列,即\hat{z}
		loss: VQ Loss
		"""
        K, D = self.embedding.size() # K表示码字总数/码本大小,D表示码字维度
        x_flat = x.detach().reshape(-1, D) # x:[B,T,D]->x_flat:[BxT,D]

        # torch.addmm(M,M1,M2,a,b) = bM+a(M1@M2), 其中M1@M2表示矩阵乘法
        # 计算序列x和码本中各码字之间的距离
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                    torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                    x_flat, self.embedding.t(),
                                    alpha=-2.0, beta=1.0) 
                                    
        # 选择距离最近的码字,获得的indices为相应码字的索引序列
        indices = torch.argmin(distances.float(), dim=-1)

        # F.one_hot(indices, K)是对indices进行one-hot编码
        # 例,F.one_hot([5,3,2,4,1], 6)将[5,3,2,4,1]编码为
        # [[0,0,0,0,0,1]
        #  [0,0,0,1,0,0]
        #  [0,0,1,0,0,0]
        #  [0,0,0,0,1,0]
        #  [0,1,0,0,0,0]]
        encodings = F.one_hot(indices, K).float() # encodings为索引序列indices的one-hot编码

        ### 获得相应的码字 ###
        # F.embedding(indices, self.embedding)用于使用索引indices在固定码本self.embedding中检索码字
        quantized = F.embedding(indices, self.embedding) # quantized为检索到的相应的码字
        quantized = quantized.view_as(x) # [BxT,D]->[B,T,D]

        ### 使用EMA方法更新码本 ###
        if self.training:
            # self.ema_count即为EMA更新公式中的N(t),其中的第i个元素表示所有数据中与第i个码字对应的连续变量x_i的数量
            # torch.sum(encodings, dim=0)即EMA更新公式中的n(t),其中的第i个元素表示当前batch中与第i个码字对应的连续变量x_i的数量
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0) 
            n = torch.sum(self.ema_count)
            self.ema_count = (self.ema_count + self.epsilon) / (n + D * self.epsilon) * n
            # dw即EMA更新公式中的\sum{z_{i,j}},第i个元素即当前batch中与第i个码字对应的连续元素的和
            dw = torch.matmul(encodings.t(), x_flat) 
            #self.ema_weight即EMA更新公式中的m(t),第i个元素即所有batch中与第i个码字对应的连续变量的和
            self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw 
            # 更新码本中的码字
            self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

        ### 计算VQ Loss ###
        # VQ损失,固定quantized(因为这一项已通过上述EMA方法更新),使x向quantized更靠近
        e_latent_loss = F.mse_loss(x, quantized.detach())
        loss = self.commitment_cost * e_latent_loss

        ### 使用stop-gradient operator, 便于反向传播计算梯度###
        # .detach()即使用了stop-gradient operator,在反向传播的时候只计算对x的梯度
        quantized = x + (quantized - x).detach() 

        return quantized, loss

2. 使用MSE损失更新码本时,VQ部分的代码

代码参考链接:VQ-MSE,为与上部分统一,有修改

class VQEmbedding(nn.Module):
    def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25):
        """
        n_embeddings: 码本大小,即码字总数
        embedding_dim: 每个码字的维度
        commitment_cost: 在总体损失函数中,commitment loss前的系数,即commitment cost
        """
        super(VQEmbedding, self).__init__()
        self.commitment_cost = commitment_cost

        ### 初始化码本 ###
        init_bound = 1 / n_embeddings
        self.embedding = nn.Embedding(n_embeddings, embedding_dim)
        # 从均匀分布U[-1/512,1/512]中抽样数值对tensor进行填充 
        self.embedding.weight.data.uniform_(-init_bound, init_bound) 
    
    def forward(self, x):
        """
        输入:
        x: [B, T, D],为连续向量序列
        ----------------------------------
        输出:
        quantized: 离散化后的序列,即\hat{z}
        loss: VQ Loss
        """
        K, D = self.embedding.size() # K表示码字总数/码本大小,D表示码字维度
        x_flat = x.detach().reshape(-1, D) # x:[B,T,D]->x_flat:[BxT,D]

        # torch.addmm(M,M1,M2,a,b) = bM+a(M1@M2), 其中M1@M2表示矩阵乘法
        # 计算序列x和码本中各码字之间的距离
        distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
                                    torch.sum(x_flat ** 2, dim=1, keepdim=True),
                                    x_flat, self.embedding.t(),
                                    alpha=-2.0, beta=1.0) 
                                    
        # 选择距离最近的码字,获得的indices为相应码字的索引序列
        indices = torch.argmin(distances.float(), dim=-1)

        ### 获得相应的码字 ###
        quantized = self.embedding(indices) # quantized为检索到的相应的码字
        quantized = quantized.view_as(x) # [BxT,D]->[B,T,D]

        if not self.training:
            return quantized

        ### 计算VQ Loss ###
        # VQ损失,固定x,使quantized向x更靠近
        q_latent_loss = F.mse_loss(quantized, x.detach())
        # VQ损失,固定quantized,使x向quantized更靠近,即commitment loss
        e_latent_loss = F.mse_loss(x, quantized.detach())
        # 整体VQ损失
        loss = q_latent_loss + self.commitment_cost*e_latent_loss

        ### 使用stop-gradient operator, 便于反向传播计算梯度###
        # .detach()即使用了stop-gradient operator,在反向传播的时候只计算对x的梯度
        quantized = x + (quantized - x).detach() 

        return quantized, loss```
  • 6
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
矢量量化器(Vector quantizer)是一种用于数据压缩和数据量化的技术。在python中,可以使用scikit-learn库中的聚类算法来实现矢量量化器。 首先,我们需要导入scikit-learn库和需要的数据。假设我们有一个包含n个d维数据点的数据集X。 ```python from sklearn.cluster import KMeans # 导入数据 X = ... # 创建矢量量化器对象 quantizer = KMeans(n_clusters=k) # 应用矢量量化器进行聚类 quantizer.fit(X) # 获取聚类中心 centroids = quantizer.cluster_centers_ # 使用聚类中心来量化数据 quantized_data = quantizer.predict(X) ``` 在上面的代码中,我们首先创建了一个KMeans对象作为矢量量化器,其中n_clusters参数指定了聚类的个数。然后,我们使用fit()方法来对数据进行聚类,得到聚类中心。最后,我们使用predict()方法来将数据点量化到对应的聚类中心。 量化后的数据quantized_data是一组聚类标签,表示每个数据点属于哪个聚类中心。聚类中心centroids是一组d维向量,表示每个聚类的中心点。 在实际应用中,矢量量化器可以用于数据压缩、特征提取和数据聚类等任务。通过将数据点映射到聚类中心,可以减少数据的维度,并且保留了原始数据的一些特征。这可以用于降低数据存储和传输的成本,同时还可以在一定程度上提高算法的效率和准确性。 总之,使用python中的聚类算法,我们可以很方便地实现矢量量化器,并应用于各种数据分析和机器学习任务中。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值