ICML2023论文阅读记录 - Straightening Out the Straight-Through Estimator


论文:Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks

原文地址:Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks。本文是在阅读原文时的简要总结和记录。

>


Abstract

1.本文解决的任务

这项工作研究了使用直通估计(straight-through estimation)进行矢量量化(Vector Quantization, VQ)训练神经网络的挑战。

  • 本文发现训练不稳定的主要原因是模型嵌入和码向量分布之间的差异。
  • 本文确定了导致此问题的因素,包括码本梯度稀疏性和承诺损失(commitment loss)的不对称性质,这会导致码向量分配不对齐。
2.本文提出的方法
  • 码向量的仿射重新参数化
  • 引入了交替优化来减少直通估计引入的梯度误差
  • 提出了对承诺损失的改进,以确保码本表示和模型嵌入之间更好的对齐
3. 本文得到的结果
  • 这些优化方法改进了直通估计的数学近似,并最终改进了模型性能。
  • 本文方法在几种常见模型架构(例如 AlexNet、ResNet 和 ViT)上跨各种任务(包括图像分类和生成建模)均具有有效性。
4. 本文的代码和模型地址

https://minyoungg.github.io/vqtorch/


Introduction

1.动机
  • 直通估计会对训练产生负面影响,包括模型崩溃或“索引崩溃”,即在训练期间仅使用一小部分代码;
  • 先前工作提出了许多方法来减少索引崩溃的程度,但还没有工作研究训练不稳定的根本原因;
  • 本文旨在系统地研究模型崩溃的原因,并提供解决由不稳定优化引起的常见陷阱的方法.
2.本文贡献
  • 将承诺损失制定为分歧度量(a divergence measure),为理解VQ 网络提供了新的见解,以更好地理解分歧产生的原因;
  • 为了减少这种分歧,提出了码向量的仿射重新参数化,以更好地匹配嵌入表示的矩(moments),仅此一点就大大减少了模型崩溃;
  • 对现有优化技术进行了一系列改进,例如交替优化和同步承诺损失,这两种方法简单且更新规则在数学上更正确,从而比标准方法有所改进。

Related Work:先前解决码本崩溃的工作

随机采样(stochastic sampling) 和概率松弛(probabilistic relaxation)

SQ-VAE(ICML '22)认为确定性(determinism)是码本崩溃的主要原因,并建议从与负距离成比例的分类分布 p ( z q = c j ∣ z e ) ∝ e − d ( z e − c j ) / τ p(\mathbf{z}_q=\mathbf{c}_j|\mathbf{z}_e)\propto e^{-d(\mathbf{z}_e-\mathbf{c}_j)/\tau} p(zq=cjze)ed(zecj)/τ中采样代码。其中 τ \tau τ是表示温度的标量。通常对温度进行退火以使模型在收敛时具有确定性。随机采样可能是一个瓶颈(bottleneck),因为它需要计算和存储全距离矩阵。

Repeated K-Means

Robust Training of Vector Quantized Bottleneck Models(IJCNN '20)通过在每个epoch运行K-Means来明确确保所有码向量都处于活跃状态,即强制所有代码重新初始化。 因为编码器和解码器都必须重新调整以适应新引入的码向量,它可能会导致模型性能大幅上升;但当学习率衰减时,模型将不再适应新的码向量,并且性能会下降。

Replacement Policy

SoundStream(TASLP '21)和Jukebox('20)提出用随机采样的模型嵌入替换死码向量。 最少使用(LRU)替换策略:如果码向量没有用于20次训练迭代,则会被随机模型嵌入替换。当使用替换策略时,活码向量保持不变,并且模型的整体性能不会降低。


针对VQ网络可训练性的分析

1. 承诺损失是一种不对称损失

  • 记模型嵌入 z e \mathbf{z}_e ze、量化向量 z q \mathbf{z}_q zq和码向量 c \mathbf{c} c分别属于集合 P z , Q z \mathcal{P}_z, \mathcal{Q}_z Pz,Qz C z \mathcal{C}_z Cz,其中 Q z ⊆ C z \mathcal{Q}_z \subseteq \mathcal{C}_z QzCz。传统的承诺损失可以写作
    L c m t ( z e , z q ) = ( 1 − β ) ⋅ d ( z e , s g [ z q ] ) + β d ( s g [ z e ] , z q ) \mathcal{L}_{cmt}(\mathbf{z}_e, \mathbf{z}_q)=(1-\beta)\cdot d(\mathbf{z}_e, sg[\mathbf{z}_q])+\beta d(sg[\mathbf{z}_e], \mathbf{z}_q) Lcmt(ze,zq)=(1β)d(ze,sg[zq])+βd(sg[ze],zq)
    其中 β ∈ [ 0 , 1 ] \beta \in [0, 1] β[0,1] 是一个标量,它权衡更新 z e \mathbf{z}_e ze z q \mathbf{z}_q zq的重要性(例如,大的 β \beta β意味着更注重更新码本以适应编码器)。
  • 承诺损失又可重写为在 P z \mathcal{P}_z Pz C z \mathcal{C}_z Cz 中一组对齐的点之间计算的距离 d ( ⋅ ) d(\cdot) d()的平均值:
    min ⁡ C z D ( P z , C z ) = 1 ∣ P z ∣ ∑ z i ∈ P z min ⁡ c j ∈ C z d ( z i , c j ) \mathop{\min}_{\mathcal{C}_z} D(\mathcal{P}_z, \mathcal{C}_z)=\frac{1}{|\mathcal{P}_z|}\sum_{\mathbf{z}_i\in \mathcal{P}_z} \mathop{\min}_{\mathbf{c}_j\in \mathcal{C}_z}d(\mathbf{z}_i, \mathbf{c}_j) minCzD(Pz,Cz)=Pz1ziPzmincjCzd(zi,cj)
    d ( ⋅ ) d(\cdot) d()是 Bregman 散度时(例如l2),则距离 D ( P z , C z ) D(\mathcal{P}_z, \mathcal{C}_z) D(Pz,Cz)可以看作是一组对齐的 P z \mathcal{P}_z Pz C z \mathcal{C}_z Cz上的平均散度。该散度函数是非对称的,在 P z \mathcal{P}_z Pz上计算,但相对于 C z \mathcal{C}_z Cz最小化。
  • 该散度导致多对一映射,其中一组选定的码字形成 Q z \mathcal{Q}_z Qz。 此处不相交的集合 C z ∖ Q z \mathcal{C}_z \setminus \mathcal{Q}_z CzQz没有接收任何梯度并且没有经过训练
  • 由于码字子集 Q \mathcal{Q} Q(而不是 C \mathcal{C} C)用于最小化承诺损失,因此 Q \mathcal{Q} Q学习了“模式搜索(mode-seeking)”行为。 这意味着一旦码字未被选择,它们将来可能会保持未被选择的状态。 请注意,即使码向量被初始化为在分布上与模型嵌入完美重叠,码向量也可能在优化过程中由于各种原因而被丢弃,包括训练中的随机性和非平稳模型表示 P z \mathcal{P}_z Pz

2. 梯度估计间隙

  • 记量化向量为 z q = z e + ϵ \mathbf{z}_q = \mathbf{z}_e + \epsilon zq=ze+ϵ,其中 ϵ \epsilon ϵ是由量化函数 h ( ⋅ ) h(\cdot) h()产生的残余误差向量。 当 ϵ = 0 \epsilon=0 ϵ=0时,则不存在直通估计误差,模型等同于没有量化函数(无损量化函数)。
  • 为了测量与无损量化函数的梯度偏差,梯度间隙(gradient gap)定义为:
    在这里插入图片描述
    梯度间隙测量非量化模型和量化模型的梯度之间的差异。 当 Δgap = 0 时,使用直通估计的梯度下降可以保证损失最小化。
  • 当(1)量化误差较小且(2)解码器函数G(·)平滑时,可以使该梯度间隙较小。 考虑当 z e \mathbf{z}_e ze z q \mathbf{z}_q zq相等 ϵ = 0 \epsilon=0 ϵ=0时的情况,则梯度间隙为 Δgap = 0。当它们不相等且 G(·) 是 K-Lipschitz 平滑时,估计误差按比例受到量化误差 K ⋅ d ( z e , z q ) K\cdot d(\mathbf{z}_e, \mathbf{z}_q) Kd(ze,zq)的限制。
  • 梯度间隙提供了直通估计优劣的衡量标准, 然而在实践中使用它时需要注意一个警告: 当VQ经历索引崩溃时Δgap 会急剧上升,但最终对于模型来说实现 Δgap = 0 变得非常简单,因为编码器被引导通过承诺损失来预测剩余的少数码字;活码向量越少,预测就越容易。 因此在使用梯度间隙作为比较模型的指标时应该谨慎。

Proposed Methods

1. 通过共享仿射参数化最小化内部码本协变量偏移(Minimizing internal codebook covariate shift with shared affine parameterization)

  • 内部码本协变量移位: 不断变化的内部表示 P z \mathcal{P}_z Pz会与码本分布 C z \mathcal{C}_z Cz产生偏差。
  • 当内部表示 P z \mathcal{P}_z Pz更新时,不仅 C z \mathcal{C}_z Cz需要更长的时间才能赶上(码本接受稀疏梯度),而且如果 P z \mathcal{P}_z Pz的更新太大(例如较大的学习率),分配可能会严重错位。
  • 提出具有共享全局均值和标准差的码向量的仿射重新参数化 在这里插入图片描述
    其中 c s i g n a l ( i ) \mathbf{c}_{signal}^{(i)} csignal(i)是原始码向量, c m e a n , c s t d \mathbf{c}_{mean},\mathbf{c}_{std} cmeancstd是具有相同 d i m ( c s i g n a l ( i ) ) dim(\mathbf{c}_{signal}^{(i)}) dim(csignal(i))的共享仿射参数 。仿射参数可以通过梯度下降来学习,也可以通过 z e \mathbf{z}_e ze z q \mathbf{z}_q zq统计数据的指数移动平均值来计算。重新参数化允许梯度通过仿射参数流过未选择的码向量。

2. 交替优化

  • 使用直通估计计算出的梯度是对真实梯度的有偏估计,可能会导致不良的优化动态,引起估计梯度偏离真实梯度,该偏离与量化误差成正比。 因此,当量化误差很大时更新网络可能会导致错误的模型更新。
  • 为了减少量化误差,提出一种交替优化算法:
    在这里插入图片描述
    上面的算法类似于在线 K-Means算法,1式优化 K-Means聚类,2式在给定新聚类分配的情况下优化模型。当 L c o m m i t → 0 \mathcal{L}_{commit}\to 0 Lcommit0 时, h ( ⋅ ) h(\cdot) h()充当恒等函数;然后,在固定 h h h下,F和G都可以以接近于零的估计误差进行优化。

3. 同步更新准则

  • 使用承诺损失更新的码本是模型表示的历史平均值,不考虑当前的代表性; β = 1 \beta=1 β=1时承诺损失的梯度更新可以写作:
    在这里插入图片描述
    因此,计算相对于历史平均值的梯度意味着模型收到“延迟”梯度。
  • 为了减少 z q \mathbf{z}_q zq表示的延迟,码向量应包含最新表示的运行平均值:
    在这里插入图片描述
    由于 z q \mathbf{z}_q zq的梯度用于更新 z e \mathbf{z}_e ze,因此同步更新规则的显式方程为:
    在这里插入图片描述
    使用上面的等式,码向量使用任务损失的梯度在编码器表示的方向上迈出一步。
  • 在Python中,这需要对直通估计的现有实现进行微小的改变:
    在这里插入图片描述
    其中 ν \nu ν是一个标量,用于决定pessimistic更新还是optimistic更新。 ν \nu ν的有效性取决于模型架构。

Experiment

1. 分类

  • 数据集:ImageNet100
  • 指标:测试数据集上模型的困惑度(定义为 2 H ( p ) 2^{H(p)} 2H(p),其中 H ( p ) H(p) H(p) 是码本似然的熵)。较高的困惑度意味着码字的统一分配。 虽然具有非常低的困惑度与码本崩溃相关,但拥有更高的困惑度并不一定意味着更好的性能——在具有比必要的码字更多的任务上具有高困惑度表明冗余。
  • 实验结果:
    • 使用仿射重新参数化的结果很大程度上改善了索引崩溃,并且使用同步和交替训练方法进一步提高了模型的整体性能。
    • 与最近最少使用(LRU)替换策略的使用进行比较,本文方法表现优于或表现相当。
    • 所有这些方法的结合可以最大程度地提高性能。
    • 使用 l2 标准化会损害分类性能。怀疑删除嵌入的幅度分量会损害使用幅度敏感目标的模型(例如软最大交叉熵损失)。
      在这里插入图片描述

2. 生成建模

  • 数据集:CelebA, CIFAR10
  • baselines: VQVAE,SVQ-VAE, Gumbel-VQVAE,其中SQVAE 需要比所有其他方法多 4 倍的内存占用,因为它需要存储完整距离矩阵以及计算图。
  • 结果
    • 使用 l2 归一化和最近最少使用 (LRU) 替换策略的两个基线在很大程度上提高了生成模型的训练稳定性和重建性能。
    • 当这些方法与本文方法联合应用时,我们观察到最好的改进。
      在这里插入图片描述

3. 预热(Warmup)和标准化(Normalization)可能会有所帮助

  • 减轻码本和模型嵌入分布之间的差异的一种方法是限制模型嵌入分布的移动范围,促进了整个训练过程中码本和嵌入分布之间的对齐;常见的技术包括L2归一化、批量归一化以及概率VQ 中假设受限分布。然而,这些技术以降低模型表达能力为代价提高了稳定性。
  • 另一种方法是确保 Pz 的更新较小以便码本能够跟上;使用具有预热功能的学习率调度器即可做到这一点。

4. 交替优化的消融实验

  • 本文测量了内部和外部循环迭代次数的变化如何影响分类性能;发现:通过将内部循环迭代次数增加 8 次,性能比基线提高了 11.09%,比使用内部步骤单次迭代的版本提高了 5.51%。 另一方面,我们没有发现增加外循环有帮助。 当组合我们所有的方法时,我们发现将内循环迭代设置为 1-2 就足够了。
    在这里插入图片描述

5. 进一步减少VQ的稀疏性

  • 为了进一步减少码本更新中的稀疏性,可以直接改进有助于稀疏性的架构设计选择。 具体来说,图像大小、批量大小和池化层数量等因素对 VQN 性能有显着影响,因为码向量选择的数量直接取决于这些变量。
  • 当将图像尺寸从 256 × 256 减小到 128 × 128 时,性能会显着下降,导致性能下降超过 20%。 这表明VQ训练设计选择的重要性。
  • 9
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值