如何为不可导操作设计梯度?——聊聊 VQ-VAE 中的 Straight-Through Estimator
在深度学习中,梯度是优化的核心驱动力。然而,有些操作天生不可导,比如离散化操作(例如 argmin
),这就给模型训练带来了麻烦。今天我们要聊的,就是一种巧妙的解决方案——Straight-Through Estimator(直通估计器),它在 VQ-VAE(Vector Quantized Variational Autoencoder)中发挥了关键作用。
背景:普通的自编码器 vs VQ-VAE
先从普通的自编码器(Autoencoder)说起。假设输入是 ( x x x ),编码器(Encoder)将其映射为潜变量 ( z z z ),解码器(Decoder)再从 ( z z z ) 重构出 ( x ^ \hat{x} x^ )。训练时,我们通常最小化重构误差:
L = ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z)||_2^2 L=∣∣x−decoder(z)∣∣22
这个损失函数简单明了,因为整个过程是可导的,梯度可以顺畅地通过反向传播更新编码器和解码器。
但在 VQ-VAE 中,事情变得复杂了。VQ-VAE 的核心思想是将连续的潜变量 ( z z z ) 量化为离散的 ( z q z_q zq ),具体来说是通过查找最近邻(nearest neighbor)操作实现的:
z q = argmin e i ∈ E ∣ ∣ z − e i ∣ ∣ 2 z_q = \text{argmin}_{e_i \in E} ||z - e_i||_2 zq=argminei∈E∣∣z−ei∣∣2
这里的 ( E E E ) 是一个预定义的码本(codebook),包含若干离散向量 ( e i e_i ei )。重构时,解码器用的是 ( z q z_q zq ) 而不是 ( z z z )。按理说,损失函数应该是:
L = ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z_q)||_2^2 L=∣∣x−decoder(zq)∣∣22
但问题来了:( z q z_q zq ) 的生成过程包含 ( argmin \text{argmin} argmin ),这个操作不可导!如果直接用这个损失,反向传播的梯度到 ( z q z_q zq ) 就断了,没法更新编码器。这怎么办?
矛盾:目标与优化难度的权衡
我们的目标是让 ( ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 ||x - \text{decoder}(z_q)||_2^2 ∣∣x−decoder(zq)∣∣22 ) 尽可能小,但这个目标不好优化,因为 ( z q z_q zq ) 的梯度没法算。而 ( ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 ||x - \text{decoder}(z)||_2^2 ∣∣x−decoder(z)∣∣22 ) 虽然容易优化,却不是我们想要的——毕竟 VQ-VAE 的核心在于离散化的 ( z q z_q zq ),而不是连续的 ( z z z )。
一个很“粗暴”的想法是把两个损失加起来:
L = ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 + ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z)||_2^2 + ||x - \text{decoder}(z_q)||_2^2 L=∣∣x−decoder(z)∣∣22+∣∣x−decoder(zq)∣∣22
但这并不理想。最小化 ( ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 ||x - \text{decoder}(z)||_2^2 ∣∣x−decoder(z)∣∣22 ) 会强制 ( z z z ) 去拟合 ( x x x ),这和我们离散化的初衷背道而驰,相当于引入了额外的约束。
Straight-Through Estimator:一个巧妙的解决办法
为了解决这个矛盾,VQ-VAE 引入了 Straight-Through Estimator(直通估计器),这个方法最早出现在 Bengio 的论文《Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation》中。名字听起来有点学术,但原理其实很简单:
- 前向传播:用你想要的变量,哪怕它不可导。
- 反向传播:自己为它设计一个梯度,继续传播。
在 VQ-VAE 中,目标函数被设计成这样:
L = ∣ ∣ x − decoder ( z + sg [ z q − z ] ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z + \text{sg}[z_q - z])||_2^2 L=∣∣x−decoder(z+sg[zq−z])∣∣22
这里的 ( sg \text{sg} sg) 是 stop gradient(停止梯度)的意思,表示 ( sg [ z q − z ] \text{sg}[z_q - z] sg[zq−z] ) 这部分不参与梯度计算。让我们拆解一下这个公式:
-
前向传播时:
- ( z + ( z q − z ) = z q z + (z_q - z) = z_q z+(zq−z)=zq ),所以 ( decoder ( z + sg [ z q − z ] ) \text{decoder}(z + \text{sg}[z_q - z]) decoder(z+sg[zq−z]) ) 等价于 ( decoder ( z q ) \text{decoder}(z_q) decoder(zq) )。
- 这意味着损失计算时,用的就是我们想要的 ( z q z_q zq ),完美符合目标。
-
反向传播时:
- 因为 ( sg [ z q − z ] \text{sg}[z_q - z] sg[zq−z] ) 的梯度被“停止”,所以对这部分的导数是 0。
- 于是,梯度计算时,( decoder ( z + sg [ z q − z ] ) \text{decoder}(z + \text{sg}[z_q - z]) decoder(z+sg[zq−z]) ) 就退化成了 ( decoder ( z ) \text{decoder}(z) decoder(z) )。
- 这允许梯度通过 ( z z z ) 传回编码器,编码器得以更新。
简单来说,Straight-Through Estimator 就像一个“障眼法”:前向传播时用 ( z q z_q zq ) 计算损失,反向传播时假装用的是 ( z z z ) 来传梯度。这样既保证了离散化的目标,又让优化变得可行。
更广的视角:自定义梯度的潜力
Straight-Through 的思想不仅限于 VQ-VAE,它其实为我们打开了一扇大门:我们可以为任何函数自定义梯度。比如:
- (
x
+
sg
[
relu
(
x
)
−
x
]
x + \text{sg}[\text{relu}(x) - x]
x+sg[relu(x)−x] ):
- 前向传播时等价于 ( relu ( x ) \text{relu}(x) relu(x) );
- 反向传播时梯度恒为 1(因为 ( sg [ relu ( x ) − x ] \text{sg}[\text{relu}(x) - x] sg[relu(x)−x] ) 不提供梯度,梯度全来自 ( x x x ))。
用这种方法,我们可以随便为一个函数指定梯度,比如让 ( sin ( x ) \sin(x) sin(x) ) 的梯度恒为 1,或者让某个复杂操作的梯度变成一个常数。至于这样做有没有用,就要看具体任务了——它可能是优化中的“奇技淫巧”,也可能是某种启发式设计的起点。
总结
Straight-Through Estimator 是 VQ-VAE 成功的关键,它用一种简单而优雅的方式绕过了 ( argmin \text{argmin} argmin ) 不可导的问题。通过在前向传播和反向传播中使用不同的“代理”,它既实现了离散化的目标,又保证了梯度优化的可行性。更重要的是,这种思想提醒我们:在深度学习中,梯度不一定是“天生的”,我们可以根据需求去设计它。
维护编码表(codebook)
VQ-VAE 不仅仅依赖于重构损失,还需要维护编码表(codebook),让量化后的 ( z q z_q zq ) 和编码器输出的 ( z z z ) 尽量接近。为此,损失函数增加了额外的项来约束 ( z z z ) 和 ( z q z_q zq ) 之间的距离,并通过调整比例(如 ( β \beta β ) 和 ( γ \gamma γ ))实现“让 ( z q z_q zq ) 更靠近 ( z z z )”的目标。
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
# 设置随机种子以确保结果可重复
torch.manual_seed(42)
# 模拟编码器
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(Encoder, self).__init__()
self.fc = nn.Linear(input_dim, hidden_dim)
def forward(self, x):
return self.fc(x)
# 模拟解码器
class Decoder(nn.Module):
def __init__(self, hidden_dim, output_dim):
super(Decoder, self).__init__()
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, z):
return self.fc(z)
# 向量量化函数
def quantize(z, codebook):
distances = torch.cdist(z.unsqueeze(0), codebook.unsqueeze(0)).squeeze(0)
indices = torch.argmin(distances, dim=1)
z_q = codebook[indices]
return z_q
# Straight-Through Estimator
def straight_through_estimator(z, z_q):
return z + (z_q - z).detach()
# 参数设置
input_dim = 10
hidden_dim = 4
output_dim = 10
num_codes = 5
batch_size = 2
beta = 0.25 # VQ-VAE 原论文中的默认值
gamma = 0.25 * beta # gamma = 0.25 * beta
# 初始化模型和码本(码本作为可训练参数)
encoder = Encoder(input_dim, hidden_dim)
decoder = Decoder(hidden_dim, output_dim)
codebook = nn.Parameter(torch.randn(num_codes, hidden_dim)) # 可训练的码本
# 生成模拟输入数据
x = torch.randn(batch_size, input_dim)
# 定义优化器
optimizer = torch.optim.Adam(
list(encoder.parameters()) + list(decoder.parameters()) + [codebook], lr=0.01
)
# 训练一步
def train_step(x):
optimizer.zero_grad()
# 编码器输出 z
z = encoder(x)
# 量化得到 z_q
z_q = quantize(z, codebook)
# Straight-Through Estimator
z_st = straight_through_estimator(z, z_q)
# 解码器重构
x_hat = decoder(z_st)
# 计算损失
recon_loss = F.mse_loss(x_hat, x) # 重构损失
commit_loss = F.mse_loss(z_q, z.detach()) # sg[z] - z_q
codebook_loss = F.mse_loss(z, z_q.detach()) # z - sg[z_q]
# 总损失
total_loss = recon_loss + beta * commit_loss + gamma * codebook_loss
# 反向传播
total_loss.backward()
# 更新参数
optimizer.step()
return total_loss.item(), recon_loss.item(), commit_loss.item(), codebook_loss.item(), z, z_q
# 运行训练
total_loss, recon_loss, commit_loss, codebook_loss, z, z_q = train_step(x)
# 打印结果
print(f"Total Loss: {total_loss:.4f}")
print(f"Reconstruction Loss: {recon_loss:.4f}")
print(f"Commit Loss (beta * ||sg[z] - z_q||^2): {commit_loss:.4f}")
print(f"Codebook Loss (gamma * ||z - sg[z_q]||^2): {codebook_loss:.4f}")
print(f"z (Encoder output):\n{z}")
print(f"z_q (Quantized):\n{z_q}")
# 检查码本的梯度
print(f"Gradient for codebook:\n{codebook.grad}")
代码解释
1. 损失函数
损失函数为:
L = ∣ ∣ x − decoder ( z + sg [ z q − z ] ) ∣ ∣ 2 2 + β ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 + γ ∣ ∣ z − sg [ z q ] ∣ ∣ 2 2 L = ||x - \text{decoder}(z + \text{sg}[z_q - z])||_2^2 + \beta ||\text{sg}[z] - z_q||_2^2 + \gamma ||z - \text{sg}[z_q]||_2^2 L=∣∣x−decoder(z+sg[zq−z])∣∣22+β∣∣sg[z]−zq∣∣22+γ∣∣z−sg[zq]∣∣22
- 重构损失 (
recon_loss
):( ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 ||x - \text{decoder}(z_q)||_2^2 ∣∣x−decoder(zq)∣∣22 ),通过 Straight-Through Estimator 实现。 - 约束项 1 (
commit_loss
):( β ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 \beta ||\text{sg}[z] - z_q||_2^2 β∣∣sg[z]−zq∣∣22 ),让 ( z q z_q zq ) 靠近固定的 ( z z z )。 - 约束项 2 (
codebook_loss
):( γ ∣ ∣ z − sg [ z q ] ∣ ∣ 2 2 \gamma ||z - \text{sg}[z_q]||_2^2 γ∣∣z−sg[zq]∣∣22 ),让 ( z z z ) 靠近固定的 ( z q z_q zq )。
其中:
- ( sg [ z ] \text{sg}[z] sg[z] ) 表示 ( z . d e t a c h ( ) z.detach() z.detach() ),即 ( z z z ) 的梯度被停止。
- ( sg [ z q ] \text{sg}[z_q] sg[zq] ) 表示 ( z q . d e t a c h ( ) z_q.detach() zq.detach() ),即 ( z q z_q zq ) 的梯度被停止。
- ( β \beta β ) 和 ( γ \gamma γ ) 是超参数,( γ < β \gamma < \beta γ<β )(原论文建议 ( γ = 0.25 β \gamma = 0.25\beta γ=0.25β )),以强调“让 ( z q z_q zq ) 靠近 ( z z z )”。
2. 代码中的实现细节
- 码本可训练:将
codebook
定义为nn.Parameter
,使其可以通过梯度更新。 - 损失计算:
recon_loss
:基于 ( x x x ) 和 ( x h a t x_{hat} xhat ) 的均方误差。commit_loss
:( z q z_q zq ) 和 ( z . d e t a c h ( ) z.detach() z.detach()) 的距离,鼓励 ( z q z_q zq ) 靠近 ( z z z )。codebook_loss
:( z z z ) 和 ( z q . d e t a c h ( ) z_q.detach() zq.detach()) 的距离,鼓励 ( z z z ) 靠近 ( z q z_q zq )。- 总损失加权求和:( total_loss = recon_loss + β ⋅ commit_loss + γ ⋅ codebook_loss \text{total\_loss} = \text{recon\_loss} + \beta \cdot \text{commit\_loss} + \gamma \cdot \text{codebook\_loss} total_loss=recon_loss+β⋅commit_loss+γ⋅codebook_loss )。
3. 前向传播与反向传播
- 前向传播:重构损失基于 ( z q z_q zq )(通过 Straight-Through Estimator),而约束项直接计算 ( z z z ) 和 ( z q z_q zq ) 的距离。
- 反向传播:
- (
commit_loss
\text{commit\_loss}
commit_loss ) 的梯度只更新 (
z
q
z_q
zq )(因为 (
z
z
z ) 被
detach
)。 - (
codebook_loss
\text{codebook\_loss}
codebook_loss ) 的梯度只更新 (
z
z
z )(因为 (
z
q
z_q
zq ) 被
detach
)。 - 重构损失的梯度通过 ( z z z ) 更新编码器。
- (
commit_loss
\text{commit\_loss}
commit_loss ) 的梯度只更新 (
z
q
z_q
zq )(因为 (
z
z
z ) 被
4. 输出分析
运行代码后,你会看到:
- ( z z z ) 和 ( z q z_q zq ) 的值,观察它们是否接近。
- 三部分损失的数值,( commit_loss \text{commit\_loss} commit_loss ) 和 ( codebook_loss \text{codebook\_loss} codebook_loss ) 反映了 ( z z z ) 和 ( z q z_q zq ) 的接近程度。
- 码本的梯度,确认它是否被更新(通过 ( β ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 \beta ||\text{sg}[z] - z_q||_2^2 β∣∣sg[z]−zq∣∣22 ))。
为什么这样设计?
-
( z z z ) 和 ( z q z_q zq ) 不一定接近:
- 即使重构损失很小,( z z z ) 和 ( z q z_q zq ) 可能差别很大(因为 ( f ( z 1 ) = f ( z 2 ) f(z_1) = f(z_2) f(z1)=f(z2) ) 不意味着 ( z 1 = z 2 z_1 = z_2 z1=z2 ))。
- 增加 ( ∣ ∣ z − z q ∣ ∣ 2 2 ||z - z_q||_2^2 ∣∣z−zq∣∣22 ) 显式约束了两者的距离。
-
分解 ( ∣ ∣ z − z q ∣ ∣ 2 2 ||z - z_q||_2^2 ∣∣z−zq∣∣22 ):
- ( ∣ ∣ z − z q ∣ ∣ 2 2 = ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 + ∣ ∣ z − sg [ z q ] ∣ ∣ 2 2 ||z - z_q||_2^2 = ||\text{sg}[z] - z_q||_2^2 + ||z - \text{sg}[z_q]||_2^2 ∣∣z−zq∣∣22=∣∣sg[z]−zq∣∣22+∣∣z−sg[zq]∣∣22 ) )(对于梯度而言等价,但前向传播是两倍)。
- 用 (
β
\beta
β ) 和 (
γ
\gamma
γ ) 加权,控制更新方向:
- ( β > γ \beta > \gamma β>γ ) 鼓励 ( z q z_q zq ) 更主动靠近 ( z z z ),因为码本是自由的。
- ( γ \gamma γ ) 较小,避免过度约束编码器。
-
实际意义:
- ( commit_loss \text{commit\_loss} commit_loss )(( β \beta β ) 项)也被称为“commitment loss”,确保编码器输出不会偏离码本太远。
- ( codebook_loss \text{codebook\_loss} codebook_loss )(( γ \gamma γ ) 项)辅助调整 ( z z z ),但权重较低,避免干扰重构目标。
动手实验建议
-
调整 ( β \beta β ) 和 ( γ \gamma γ ):
- 试试 ( β = 1.0 , γ = 0.0 \beta = 1.0, \gamma = 0.0 β=1.0,γ=0.0 )(只让 ( z q z_q zq ) 靠近 ( z z z ))。
- 对比 ( β = 0.25 , γ = 0.25 \beta = 0.25, \gamma = 0.25 β=0.25,γ=0.25 )(两者都起作用),观察 ( z z z ) 和 ( z q z_q zq ) 的接近程度。
-
去掉约束项:
- 将总损失改为仅包含
recon_loss
,看看 ( z z z ) 和 ( z q z_q zq ) 的差距是否变大。
- 将总损失改为仅包含
通过这个代码,你可以直观看到 VQ-VAE 如何通过损失设计维护编码表,让 ( z q z_q zq ) 和 ( z z z ) 保持一致。希望这能帮你更深入理解!
参考
https://www.spaces.ac.cn/archives/6760
轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型
后记
2025年3月10日14点37分于上海,在Grok 3大模型辅助下完成。