如何为不可导操作设计梯度?——聊聊 VQ-VAE 中的 Straight-Through Estimator(直通估计器)

如何为不可导操作设计梯度?——聊聊 VQ-VAE 中的 Straight-Through Estimator

在深度学习中,梯度是优化的核心驱动力。然而,有些操作天生不可导,比如离散化操作(例如 argmin),这就给模型训练带来了麻烦。今天我们要聊的,就是一种巧妙的解决方案——Straight-Through Estimator(直通估计器),它在 VQ-VAE(Vector Quantized Variational Autoencoder)中发挥了关键作用。

背景:普通的自编码器 vs VQ-VAE

先从普通的自编码器(Autoencoder)说起。假设输入是 ( x x x ),编码器(Encoder)将其映射为潜变量 ( z z z ),解码器(Decoder)再从 ( z z z ) 重构出 ( x ^ \hat{x} x^ )。训练时,我们通常最小化重构误差:

L = ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z)||_2^2 L=∣∣xdecoder(z)22

这个损失函数简单明了,因为整个过程是可导的,梯度可以顺畅地通过反向传播更新编码器和解码器。

但在 VQ-VAE 中,事情变得复杂了。VQ-VAE 的核心思想是将连续的潜变量 ( z z z ) 量化为离散的 ( z q z_q zq ),具体来说是通过查找最近邻(nearest neighbor)操作实现的:

z q = argmin e i ∈ E ∣ ∣ z − e i ∣ ∣ 2 z_q = \text{argmin}_{e_i \in E} ||z - e_i||_2 zq=argmineiE∣∣zei2

这里的 ( E E E ) 是一个预定义的码本(codebook),包含若干离散向量 ( e i e_i ei )。重构时,解码器用的是 ( z q z_q zq ) 而不是 ( z z z )。按理说,损失函数应该是:

L = ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z_q)||_2^2 L=∣∣xdecoder(zq)22

但问题来了:( z q z_q zq ) 的生成过程包含 ( argmin \text{argmin} argmin ),这个操作不可导!如果直接用这个损失,反向传播的梯度到 ( z q z_q zq ) 就断了,没法更新编码器。这怎么办?

矛盾:目标与优化难度的权衡

我们的目标是让 ( ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 ||x - \text{decoder}(z_q)||_2^2 ∣∣xdecoder(zq)22 ) 尽可能小,但这个目标不好优化,因为 ( z q z_q zq ) 的梯度没法算。而 ( ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 ||x - \text{decoder}(z)||_2^2 ∣∣xdecoder(z)22 ) 虽然容易优化,却不是我们想要的——毕竟 VQ-VAE 的核心在于离散化的 ( z q z_q zq ),而不是连续的 ( z z z )。

一个很“粗暴”的想法是把两个损失加起来:

L = ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 + ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z)||_2^2 + ||x - \text{decoder}(z_q)||_2^2 L=∣∣xdecoder(z)22+∣∣xdecoder(zq)22

但这并不理想。最小化 ( ∣ ∣ x − decoder ( z ) ∣ ∣ 2 2 ||x - \text{decoder}(z)||_2^2 ∣∣xdecoder(z)22 ) 会强制 ( z z z ) 去拟合 ( x x x ),这和我们离散化的初衷背道而驰,相当于引入了额外的约束。

Straight-Through Estimator:一个巧妙的解决办法

为了解决这个矛盾,VQ-VAE 引入了 Straight-Through Estimator(直通估计器),这个方法最早出现在 Bengio 的论文《Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation》中。名字听起来有点学术,但原理其实很简单:

  • 前向传播:用你想要的变量,哪怕它不可导。
  • 反向传播:自己为它设计一个梯度,继续传播。

在 VQ-VAE 中,目标函数被设计成这样:

L = ∣ ∣ x − decoder ( z + sg [ z q − z ] ) ∣ ∣ 2 2 L = ||x - \text{decoder}(z + \text{sg}[z_q - z])||_2^2 L=∣∣xdecoder(z+sg[zqz])22

这里的 ( sg \text{sg} sg) 是 stop gradient(停止梯度)的意思,表示 ( sg [ z q − z ] \text{sg}[z_q - z] sg[zqz] ) 这部分不参与梯度计算。让我们拆解一下这个公式:

  1. 前向传播时

    • ( z + ( z q − z ) = z q z + (z_q - z) = z_q z+(zqz)=zq ),所以 ( decoder ( z + sg [ z q − z ] ) \text{decoder}(z + \text{sg}[z_q - z]) decoder(z+sg[zqz]) ) 等价于 ( decoder ( z q ) \text{decoder}(z_q) decoder(zq) )。
    • 这意味着损失计算时,用的就是我们想要的 ( z q z_q zq ),完美符合目标。
  2. 反向传播时

    • 因为 ( sg [ z q − z ] \text{sg}[z_q - z] sg[zqz] ) 的梯度被“停止”,所以对这部分的导数是 0。
    • 于是,梯度计算时,( decoder ( z + sg [ z q − z ] ) \text{decoder}(z + \text{sg}[z_q - z]) decoder(z+sg[zqz]) ) 就退化成了 ( decoder ( z ) \text{decoder}(z) decoder(z) )。
    • 这允许梯度通过 ( z z z ) 传回编码器,编码器得以更新。

简单来说,Straight-Through Estimator 就像一个“障眼法”:前向传播时用 ( z q z_q zq ) 计算损失,反向传播时假装用的是 ( z z z ) 来传梯度。这样既保证了离散化的目标,又让优化变得可行。

更广的视角:自定义梯度的潜力

Straight-Through 的思想不仅限于 VQ-VAE,它其实为我们打开了一扇大门:我们可以为任何函数自定义梯度。比如:

  • ( x + sg [ relu ( x ) − x ] x + \text{sg}[\text{relu}(x) - x] x+sg[relu(x)x] ):
    • 前向传播时等价于 ( relu ( x ) \text{relu}(x) relu(x) );
    • 反向传播时梯度恒为 1(因为 ( sg [ relu ( x ) − x ] \text{sg}[\text{relu}(x) - x] sg[relu(x)x] ) 不提供梯度,梯度全来自 ( x x x ))。

用这种方法,我们可以随便为一个函数指定梯度,比如让 ( sin ⁡ ( x ) \sin(x) sin(x) ) 的梯度恒为 1,或者让某个复杂操作的梯度变成一个常数。至于这样做有没有用,就要看具体任务了——它可能是优化中的“奇技淫巧”,也可能是某种启发式设计的起点。

总结

Straight-Through Estimator 是 VQ-VAE 成功的关键,它用一种简单而优雅的方式绕过了 ( argmin \text{argmin} argmin ) 不可导的问题。通过在前向传播和反向传播中使用不同的“代理”,它既实现了离散化的目标,又保证了梯度优化的可行性。更重要的是,这种思想提醒我们:在深度学习中,梯度不一定是“天生的”,我们可以根据需求去设计它。

维护编码表(codebook)

VQ-VAE 不仅仅依赖于重构损失,还需要维护编码表(codebook),让量化后的 ( z q z_q zq ) 和编码器输出的 ( z z z ) 尽量接近。为此,损失函数增加了额外的项来约束 ( z z z ) 和 ( z q z_q zq ) 之间的距离,并通过调整比例(如 ( β \beta β ) 和 ( γ \gamma γ ))实现“让 ( z q z_q zq ) 更靠近 ( z z z )”的目标。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# 设置随机种子以确保结果可重复
torch.manual_seed(42)

# 模拟编码器
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.fc = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        return self.fc(x)

# 模拟解码器
class Decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        return self.fc(z)

# 向量量化函数
def quantize(z, codebook):
    distances = torch.cdist(z.unsqueeze(0), codebook.unsqueeze(0)).squeeze(0)
    indices = torch.argmin(distances, dim=1)
    z_q = codebook[indices]
    return z_q

# Straight-Through Estimator
def straight_through_estimator(z, z_q):
    return z + (z_q - z).detach()

# 参数设置
input_dim = 10
hidden_dim = 4
output_dim = 10
num_codes = 5
batch_size = 2
beta = 0.25  # VQ-VAE 原论文中的默认值
gamma = 0.25 * beta  # gamma = 0.25 * beta

# 初始化模型和码本(码本作为可训练参数)
encoder = Encoder(input_dim, hidden_dim)
decoder = Decoder(hidden_dim, output_dim)
codebook = nn.Parameter(torch.randn(num_codes, hidden_dim))  # 可训练的码本

# 生成模拟输入数据
x = torch.randn(batch_size, input_dim)

# 定义优化器
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()) + [codebook], lr=0.01
)

# 训练一步
def train_step(x):
    optimizer.zero_grad()
    
    # 编码器输出 z
    z = encoder(x)
    
    # 量化得到 z_q
    z_q = quantize(z, codebook)
    
    # Straight-Through Estimator
    z_st = straight_through_estimator(z, z_q)
    
    # 解码器重构
    x_hat = decoder(z_st)
    
    # 计算损失
    recon_loss = F.mse_loss(x_hat, x)  # 重构损失
    commit_loss = F.mse_loss(z_q, z.detach())  # sg[z] - z_q
    codebook_loss = F.mse_loss(z, z_q.detach())  # z - sg[z_q]
    
    # 总损失
    total_loss = recon_loss + beta * commit_loss + gamma * codebook_loss
    
    # 反向传播
    total_loss.backward()
    
    # 更新参数
    optimizer.step()
    
    return total_loss.item(), recon_loss.item(), commit_loss.item(), codebook_loss.item(), z, z_q

# 运行训练
total_loss, recon_loss, commit_loss, codebook_loss, z, z_q = train_step(x)

# 打印结果
print(f"Total Loss: {total_loss:.4f}")
print(f"Reconstruction Loss: {recon_loss:.4f}")
print(f"Commit Loss (beta * ||sg[z] - z_q||^2): {commit_loss:.4f}")
print(f"Codebook Loss (gamma * ||z - sg[z_q]||^2): {codebook_loss:.4f}")
print(f"z (Encoder output):\n{z}")
print(f"z_q (Quantized):\n{z_q}")

# 检查码本的梯度
print(f"Gradient for codebook:\n{codebook.grad}")

代码解释

1. 损失函数

损失函数为:

L = ∣ ∣ x − decoder ( z + sg [ z q − z ] ) ∣ ∣ 2 2 + β ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 + γ ∣ ∣ z − sg [ z q ] ∣ ∣ 2 2 L = ||x - \text{decoder}(z + \text{sg}[z_q - z])||_2^2 + \beta ||\text{sg}[z] - z_q||_2^2 + \gamma ||z - \text{sg}[z_q]||_2^2 L=∣∣xdecoder(z+sg[zqz])22+β∣∣sg[z]zq22+γ∣∣zsg[zq]22

  • 重构损失 (recon_loss):( ∣ ∣ x − decoder ( z q ) ∣ ∣ 2 2 ||x - \text{decoder}(z_q)||_2^2 ∣∣xdecoder(zq)22 ),通过 Straight-Through Estimator 实现。
  • 约束项 1 (commit_loss):( β ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 \beta ||\text{sg}[z] - z_q||_2^2 β∣∣sg[z]zq22 ),让 ( z q z_q zq ) 靠近固定的 ( z z z )。
  • 约束项 2 (codebook_loss):( γ ∣ ∣ z − sg [ z q ] ∣ ∣ 2 2 \gamma ||z - \text{sg}[z_q]||_2^2 γ∣∣zsg[zq]22 ),让 ( z z z ) 靠近固定的 ( z q z_q zq )。

其中:

  • ( sg [ z ] \text{sg}[z] sg[z] ) 表示 ( z . d e t a c h ( ) z.detach() z.detach() ),即 ( z z z ) 的梯度被停止。
  • ( sg [ z q ] \text{sg}[z_q] sg[zq] ) 表示 ( z q . d e t a c h ( ) z_q.detach() zq.detach() ),即 ( z q z_q zq ) 的梯度被停止。
  • ( β \beta β ) 和 ( γ \gamma γ ) 是超参数,( γ < β \gamma < \beta γ<β )(原论文建议 ( γ = 0.25 β \gamma = 0.25\beta γ=0.25β )),以强调“让 ( z q z_q zq ) 靠近 ( z z z )”。
2. 代码中的实现细节
  • 码本可训练:将 codebook 定义为 nn.Parameter,使其可以通过梯度更新。
  • 损失计算
    • recon_loss:基于 ( x x x ) 和 ( x h a t x_{hat} xhat ) 的均方误差。
    • commit_loss:( z q z_q zq ) 和 ( z . d e t a c h ( ) z.detach() z.detach()) 的距离,鼓励 ( z q z_q zq ) 靠近 ( z z z )。
    • codebook_loss:( z z z ) 和 ( z q . d e t a c h ( ) z_q.detach() zq.detach()) 的距离,鼓励 ( z z z ) 靠近 ( z q z_q zq )。
    • 总损失加权求和:( total_loss = recon_loss + β ⋅ commit_loss + γ ⋅ codebook_loss \text{total\_loss} = \text{recon\_loss} + \beta \cdot \text{commit\_loss} + \gamma \cdot \text{codebook\_loss} total_loss=recon_loss+βcommit_loss+γcodebook_loss )。
3. 前向传播与反向传播
  • 前向传播:重构损失基于 ( z q z_q zq )(通过 Straight-Through Estimator),而约束项直接计算 ( z z z ) 和 ( z q z_q zq ) 的距离。
  • 反向传播
    • ( commit_loss \text{commit\_loss} commit_loss ) 的梯度只更新 ( z q z_q zq )(因为 ( z z z ) 被 detach)。
    • ( codebook_loss \text{codebook\_loss} codebook_loss ) 的梯度只更新 ( z z z )(因为 ( z q z_q zq ) 被 detach)。
    • 重构损失的梯度通过 ( z z z ) 更新编码器。
4. 输出分析

运行代码后,你会看到:

  • ( z z z ) 和 ( z q z_q zq ) 的值,观察它们是否接近。
  • 三部分损失的数值,( commit_loss \text{commit\_loss} commit_loss ) 和 ( codebook_loss \text{codebook\_loss} codebook_loss ) 反映了 ( z z z ) 和 ( z q z_q zq ) 的接近程度。
  • 码本的梯度,确认它是否被更新(通过 ( β ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 \beta ||\text{sg}[z] - z_q||_2^2 β∣∣sg[z]zq22 ))。

为什么这样设计?

  1. ( z z z ) 和 ( z q z_q zq ) 不一定接近

    • 即使重构损失很小,( z z z ) 和 ( z q z_q zq ) 可能差别很大(因为 ( f ( z 1 ) = f ( z 2 ) f(z_1) = f(z_2) f(z1)=f(z2) ) 不意味着 ( z 1 = z 2 z_1 = z_2 z1=z2 ))。
    • 增加 ( ∣ ∣ z − z q ∣ ∣ 2 2 ||z - z_q||_2^2 ∣∣zzq22 ) 显式约束了两者的距离。
  2. 分解 ( ∣ ∣ z − z q ∣ ∣ 2 2 ||z - z_q||_2^2 ∣∣zzq22 )

    • ( ∣ ∣ z − z q ∣ ∣ 2 2 = ∣ ∣ sg [ z ] − z q ∣ ∣ 2 2 + ∣ ∣ z − sg [ z q ] ∣ ∣ 2 2 ||z - z_q||_2^2 = ||\text{sg}[z] - z_q||_2^2 + ||z - \text{sg}[z_q]||_2^2 ∣∣zzq22=∣∣sg[z]zq22+∣∣zsg[zq]22 ) )(对于梯度而言等价,但前向传播是两倍)。
    • 用 ( β \beta β ) 和 ( γ \gamma γ ) 加权,控制更新方向:
      • ( β > γ \beta > \gamma β>γ ) 鼓励 ( z q z_q zq ) 更主动靠近 ( z z z ),因为码本是自由的。
      • ( γ \gamma γ ) 较小,避免过度约束编码器。
  3. 实际意义

    • ( commit_loss \text{commit\_loss} commit_loss )(( β \beta β ) 项)也被称为“commitment loss”,确保编码器输出不会偏离码本太远。
    • ( codebook_loss \text{codebook\_loss} codebook_loss )(( γ \gamma γ ) 项)辅助调整 ( z z z ),但权重较低,避免干扰重构目标。

动手实验建议

  1. 调整 ( β \beta β ) 和 ( γ \gamma γ )

    • 试试 ( β = 1.0 , γ = 0.0 \beta = 1.0, \gamma = 0.0 β=1.0,γ=0.0 )(只让 ( z q z_q zq ) 靠近 ( z z z ))。
    • 对比 ( β = 0.25 , γ = 0.25 \beta = 0.25, \gamma = 0.25 β=0.25,γ=0.25 )(两者都起作用),观察 ( z z z ) 和 ( z q z_q zq ) 的接近程度。
  2. 去掉约束项

    • 将总损失改为仅包含 recon_loss,看看 ( z z z ) 和 ( z q z_q zq ) 的差距是否变大。

通过这个代码,你可以直观看到 VQ-VAE 如何通过损失设计维护编码表,让 ( z q z_q zq ) 和 ( z z z ) 保持一致。希望这能帮你更深入理解!

参考

https://www.spaces.ac.cn/archives/6760

轻松理解 VQ-VAE:首个提出 codebook 机制的生成模型

后记

2025年3月10日14点37分于上海,在Grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值