Energy Matching中的训练目标分析

Energy Matching中的训练目标分析

在《Energy Matching: Unifying Flow Matching and Energy-Based Models for Generative Modeling》一文中,作者提出了一种新颖的生成模型框架——Energy Matching,通过结合最优传输(Optimal Transport, OT)和能量基础模型(Energy-Based Models, EBMs)的优势,实现高效的样本生成和显式的似然估计。本文将详细分析论文中2.1节“Training Objectives”部分,重点探讨Flow-like Objective (( L OT \mathcal{L}_{\text{OT}} LOT))Contrastive Objective (( L CD \mathcal{L}_{\text{CD}} LCD)) 以及 Dual Objective 的定义、作用和实现方式,并结合直观解释和上下文背景,帮助读者深入理解这些训练目标在Energy Matching框架中的重要性。


paper:https://arxiv.org/pdf/2504.10612

背景:Energy Matching的训练目标

具体背景知识可以参考笔者的另一篇博客:JKO方案中的一阶最优性条件与生成框架

Energy Matching的核心思想是通过学习一个时间无关的标量势函数 ( V θ ( x ) V_\theta(x) Vθ(x) ),实现从噪声分布到数据分布的平滑过渡。其训练过程分为两个阶段:

  1. 远离数据流形(Flow-like Regime):当样本距离数据流形较远时,利用最优传输的特性,通过确定性流快速将噪声样本引导到数据流形附近。此时,温度参数 ( ε ( t ) ≈ 0 \varepsilon(t) \approx 0 ε(t)0),训练目标主要由 ( L OT \mathcal{L}_{\text{OT}} LOT) 驱动。
  2. 接近数据流形(EBM-like Regime):当样本接近数据流形时,引入熵项(( ε ( t ) ≈ ε max \varepsilon(t) \approx \varepsilon_{\text{max}} ε(t)εmax)),通过对比散度(Contrastive Divergence)优化,使 ( V θ ( x ) V_\theta(x) Vθ(x) ) 形成一个Boltzmann分布,精确匹配数据分布。此时,训练目标结合 ( L OT \mathcal{L}_{\text{OT}} LOT) 和 ( L CD \mathcal{L}_{\text{CD}} LCD)。

2.1节详细描述了这两个训练目标以及它们的联合优化方式,以下逐一分析。


2.1.1 Flow-like Objective ( L OT \mathcal{L}_{\text{OT}} LOT)

定义

Flow-like Objective ( L OT \mathcal{L}_{\text{OT}} LOT) 的目标是构建一个全局速度场 ( ∇ x V θ ( x ) \nabla_x V_\theta(x) xVθ(x)),将噪声样本 ( { x 0 } \{x_0\} {x0}) 高效地运送到数据样本 ( { x data } \{x_{\text{data}}\} {xdata}),尽量减少路径上的“弯路”。为此,论文利用Wasserstein空间中的测地线(geodesics)来定义传输路径。具体而言,( L OT \mathcal{L}_{\text{OT}} LOT) 的损失函数定义为:

L OT = E t ∼ U ( 0 , 1 ) [ ∥ ∇ x V θ ( x t ( i ) ) + x data ( i ) − T ( x data ( i ) ) ∥ 2 ] \mathcal{L}_{\text{OT}} = \mathbb{E}_{t \sim U(0,1)} \left[ \left\| \nabla_x V_\theta(x_t^{(i)}) + x_{\text{data}}^{(i)} - T(x_{\text{data}}^{(i)}) \right\|^2 \right] LOT=EtU(0,1)[ xVθ(xt(i))+xdata(i)T(xdata(i)) 2]

其中:

  • ( x t ( i ) = ( 1 − t ) T ( x data ( i ) ) + t x data ( i ) x_t^{(i)} = (1-t) T(x_{\text{data}}^{(i)}) + t x_{\text{data}}^{(i)} xt(i)=(1t)T(xdata(i))+txdata(i) ):表示沿测地线的插值点,( t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1] ),从噪声样本 ( T ( x data ( i ) ) T(x_{\text{data}}^{(i)}) T(xdata(i)) ) 到数据样本 ( x data ( i ) x_{\text{data}}^{(i)} xdata(i))。
  • ( T T T ):最优传输映射(OT map),通过OT求解器(如POT库)计算,定义了从噪声分布到数据分布的配对。
  • ( ∇ x V θ ( x t ( i ) ) \nabla_x V_\theta(x_t^{(i)}) xVθ(xt(i))):势函数 ( V θ ( x ) V_\theta(x) Vθ(x) ) 在插值点 ( x t ( i ) x_t^{(i)} xt(i) ) 的梯度,表示模型预测的速度场。
  • ( x data ( i ) − T ( x data ( i ) ) x_{\text{data}}^{(i)} - T(x_{\text{data}}^{(i)}) xdata(i)T(xdata(i)) ):目标速度,表示样本从噪声到数据的理想移动方向。

实现方式

  1. 数据准备

    • 从数据分布中采样一个mini-batch ( { x data ( i ) } i = 1 B \{x_{\text{data}}^{(i)}\}_{i=1}^B {xdata(i)}i=1B),表示真实数据样本。
    • 从噪声分布(通常为标准高斯分布 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I))) 采样等数量的噪声样本 ( { x 0 ( i ) } i = 1 B \{x_0^{(i)}\}_{i=1}^B {x0(i)}i=1B)。
    • 使用OT求解器计算最优传输映射 ( T T T ),将噪声样本 ( { x 0 ( i ) } \{x_0^{(i)}\} {x0(i)}) 配对到数据样本 ( { x data ( i ) } \{x_{\text{data}}^{(i)}\} {xdata(i)})。
  2. 插值与速度计算

    • 对于每个数据样本 ( x data ( i ) x_{\text{data}}^{(i)} xdata(i) ),通过线性插值生成路径上的点:
      x t ( i ) = ( 1 − t ) T ( x data ( i ) ) + t x data ( i ) x_t^{(i)} = (1-t) T(x_{\text{data}}^{(i)}) + t x_{\text{data}}^{(i)} xt(i)=(1t)T(xdata(i))+txdata(i)
    • 目标速度为 ( x data ( i ) − T ( x data ( i ) ) x_{\text{data}}^{(i)} - T(x_{\text{data}}^{(i)}) xdata(i)T(xdata(i)) ),表示样本以恒定速度从噪声移动到数据。
  3. 优化目标

    • 优化 ( L OT \mathcal{L}_{\text{OT}} LOT),使模型预测的梯度 ( − ∇ x V θ ( x t ( i ) ) -\nabla_x V_\theta(x_t^{(i)}) xVθ(xt(i))) 尽量接近目标速度 ( x data ( i ) − T ( x data ( i ) ) x_{\text{data}}^{(i)} - T(x_{\text{data}}^{(i)}) xdata(i)T(xdata(i)) )。
    • 期望操作 ( E t ∼ U ( 0 , 1 ) \mathbb{E}_{t \sim U(0,1)} EtU(0,1)) 确保损失在整个路径 ( [ 0 , 1 ] [0, 1] [0,1]) 上平均计算。

直观解释

( L OT \mathcal{L}_{\text{OT}} LOT) 的作用类似于“导航系统”,为样本从噪声到数据的移动提供方向。想象一群粒子(噪声样本)需要移动到目标位置(数据样本),( L OT \mathcal{L}_{\text{OT}} LOT) 确保这些粒子沿着最短路径(Wasserstein测地线)移动,并且模型预测的“推力” ( − ∇ x V θ ( x ) -\nabla_x V_\theta(x) xVθ(x)) 与理想路径一致。

这一目标的特别之处在于:

  • 无旋条件(Curl-free):由于速度场由标量势函数的梯度 ( ∇ x V θ ( x ) \nabla_x V_\theta(x) xVθ(x)) 定义,它是无旋的(curl-free),这与最优传输的特性一致,避免了不必要的旋转路径,降低了传输成本。
  • 时间无关:与传统Flow Matching方法使用时间依赖的速度场不同,( L OT \mathcal{L}_{\text{OT}} LOT) 假设速度场是静态的,仅依赖于 ( V θ ( x ) V_\theta(x) Vθ(x) ),简化了模型设计。

在框架中的作用

( L OT \mathcal{L}_{\text{OT}} LOT) 是第一阶段(Phase 1,Algorithm 1)的主要训练目标,确保噪声样本快速且高效地接近数据流形。通过预训练 ( V θ ( x ) V_\theta(x) Vθ(x) ) 以形成平滑的传输路径,( L OT \mathcal{L}_{\text{OT}} LOT) 为后续的对比散度优化提供了高质量的初始样本,避免了EBM训练中的模式崩塌问题。


2.1.2 Contrastive Objective ( L CD \mathcal{L}_{\text{CD}} LCD)

定义

Contrastive Objective ( L CD \mathcal{L}_{\text{CD}} LCD) 的目标是调整 ( V θ ( x ) V_\theta(x) Vθ(x) ),使平衡分布 ( ρ eq ( x ) ∝ exp ⁡ ( − V θ ( x ) ε max ) \rho_{\text{eq}}(x) \propto \exp\left(-\frac{V_\theta(x)}{\varepsilon_{\text{max}}}\right) ρeq(x)exp(εmaxVθ(x))) 精确匹配数据分布。论文采用经典的对比散度(Contrastive Divergence)损失,定义为:

L CD = E x ∼ ρ data [ V θ ( x ) ε max ] − E x ~ ∼ sg ( ρ eq ) [ V θ ( x ~ ) ε max ] \mathcal{L}_{\text{CD}} = \mathbb{E}_{x \sim \rho_{\text{data}}} \left[ \frac{V_\theta(x)}{\varepsilon_{\text{max}}} \right] - \mathbb{E}_{\tilde{x} \sim \text{sg}(\rho_{\text{eq}})} \left[ \frac{V_\theta(\tilde{x})}{\varepsilon_{\text{max}}} \right] LCD=Exρdata[εmaxVθ(x)]Ex~sg(ρeq)[εmaxVθ(x~)]

其中:

  • ( ρ data \rho_{\text{data}} ρdata):真实数据分布,样本 ( x x x ) 从中抽取。
  • ( ρ eq \rho_{\text{eq}} ρeq):由 ( V θ ( x ) V_\theta(x) Vθ(x) ) 诱导的平衡分布,负样本 ( x ~ \tilde{x} x~) 通过MCMC朗之万链(Langevin chain)近似采样。
  • ( sg ( ⋅ ) \text{sg}(\cdot) sg()):停止梯度操作(stop-gradient),确保梯度不通过采样过程反向传播。
  • ( ε max \varepsilon_{\text{max}} εmax):最大温度参数,控制Boltzmann分布的平滑程度。

实现方式

  1. 负样本采样

    • 使用朗之万动力学(Langevin Dynamics)生成负样本 ( x ~ \tilde{x} x~),更新公式为:
      x m + 1 = x m − Δ t ∇ x V θ ( x m ) + 2 Δ t ε ( m ) η , η ∼ N ( 0 , I ) x_{m+1} = x_m - \Delta t \nabla_x V_\theta(x_m) + \sqrt{2 \Delta t \varepsilon^{(m)}} \eta, \quad \eta \sim \mathcal{N}(0, I) xm+1=xmΔtxVθ(xm)+tε(m) η,ηN(0,I)
    • 初始样本分为两部分:
      • 一半从真实数据 ( x ∼ ρ data x \sim \rho_{\text{data}} xρdata ) 初始化,确保负样本探索数据流形附近的高密度区域。
      • 一半从噪声分布(如 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I))) 初始化,探索远离数据流形的区域,塑造全局能量景观。
  2. 损失计算

    • 正样本项:计算真实数据样本 ( x x x ) 的能量期望 ( E x ∼ ρ data [ V θ ( x ) ε max ] \mathbb{E}_{x \sim \rho_{\text{data}}} \left[ \frac{V_\theta(x)}{\varepsilon_{\text{max}}} \right] Exρdata[εmaxVθ(x)]),目标是降低数据点的能量。
    • 负样本项:计算负样本 ( x ~ \tilde{x} x~) 的能量期望 ( E x ~ ∼ sg ( ρ eq ) [ V θ ( x ~ ) ε max ] \mathbb{E}_{\tilde{x} \sim \text{sg}(\rho_{\text{eq}})} \left[ \frac{V_\theta(\tilde{x})}{\varepsilon_{\text{max}}} \right] Ex~sg(ρeq)[εmaxVθ(x~)]),目标是提高非数据点的能量。
    • 总损失 ( L CD \mathcal{L}_{\text{CD}} LCD) 通过对比正负样本的能量差,优化 ( V θ ( x ) V_\theta(x) Vθ(x) ) 以形成低能量的数据流形。

直观解释

( L CD \mathcal{L}_{\text{CD}} LCD) 类似于一个“雕刻师”,通过对比数据样本和负样本的能量,雕刻出能量函数 ( V θ ( x ) V_\theta(x) Vθ(x) ) 的形状:

  • 对于真实数据点 ( x x x ),( L CD \mathcal{L}_{\text{CD}} LCD) 试图降低其能量 ( V θ ( x ) V_\theta(x) Vθ(x) ),使数据区域成为能量“洼地”。
  • 对于负样本 ( x ~ \tilde{x} x~),( L CD \mathcal{L}_{\text{CD}} LCD) 试图提高其能量 ( V θ ( x ) V_\theta(x) Vθ(x) ),使非数据区域成为能量“高地”。
  • 朗之万动力学的采样过程模拟了粒子在能量景观中的随机游走,负样本的初始化策略(数据+噪声)确保能量函数既能精确建模数据流形,又能塑造全局结构。

在框架中的作用

( L CD \mathcal{L}_{\text{CD}} LCD) 是第二阶段(Phase 2,Algorithm 2)的重要组成部分,负责在数据流形附近精细调整 ( V θ ( x ) V_\theta(x) Vθ(x) ),形成Boltzmann分布 ( ρ eq ( x ) \rho_{\text{eq}}(x) ρeq(x))。通过对比散度,模型能够捕捉数据的局部密度结构,同时避免传统EBM训练中的模式崩塌问题(由于初始样本的质量由 ( L OT \mathcal{L}_{\text{OT}} LOT) 保证)。


2.1.3 Dual Objective

定义

为了平衡 ( L OT \mathcal{L}_{\text{OT}} LOT) 和 ( L CD \mathcal{L}_{\text{CD}} LCD) 两个目标,论文采用了一种双目标优化策略,通过线性温度调度(如下)协调两个阶段的训练:

ε ( t ) = { 0 , 0 ≤ t < τ ∗ , ε max t − τ ∗ 1 − τ ∗ , τ ∗ ≤ t ≤ 1 , ε max , t ≥ 1. \varepsilon(t) = \begin{cases} 0, & 0 \leq t < \tau^*, \\ \varepsilon_{\text{max}} \frac{t - \tau^*}{1 - \tau^*}, & \tau^* \leq t \leq 1, \\ \varepsilon_{\text{max}}, & t \geq 1. \end{cases} ε(t)= 0,εmax1τtτ,εmax,0t<τ,τt1,t1.

总损失函数为:

L ( θ ) = L OT + λ CD L CD \mathcal{L}(\theta) = \mathcal{L}_{\text{OT}} + \lambda_{\text{CD}} \mathcal{L}_{\text{CD}} L(θ)=LOT+λCDLCD

其中:

  • ( λ CD \lambda_{\text{CD}} λCD):数据集特定的超参数,用于平衡 ( L CD \mathcal{L}_{\text{CD}} LCD) 的贡献。
  • 温度 ( ε ( t ) \varepsilon(t) ε(t)):通过时间调度控制熵项的影响,早期为0(强调 ( L OT \mathcal{L}_{\text{OT}} LOT)),后期逐渐增加到 ( ε max \varepsilon_{\text{max}} εmax)(引入 ( L CD \mathcal{L}_{\text{CD}} LCD))。

实现方式

  1. Phase 1(预训练,Algorithm 1)

    • 仅使用 ( L OT \mathcal{L}_{\text{OT}} LOT) 优化 ( V θ ( x ) V_\theta(x) Vθ(x) ),( ε ( t ) = 0 \varepsilon(t) = 0 ε(t)=0)。
    • 目标是建立从噪声到数据的平滑传输路径,生成高质量的负样本,为后续对比散度优化奠定基础。
    • 训练过程包括采样噪声和数据、计算OT映射、优化速度场匹配。
  2. Phase 2(主训练,Algorithm 2)

    • 联合优化 ( L OT + λ CD L CD \mathcal{L}_{\text{OT}} + \lambda_{\text{CD}} \mathcal{L}_{\text{CD}} LOT+λCDLCD),随时间 ( t t t ) 增加 ( ε ( t ) \varepsilon(t) ε(t)) 至 ( ε max \varepsilon_{\text{max}} εmax)。
    • 对于负样本采样,朗之万动力学根据当前温度 ( ε ( m ) \varepsilon^{(m)} ε(m))(或 ( ε max \varepsilon_{\text{max}} εmax) 用于数据初始化的样本)进行更新。
    • 每次迭代计算 ( L OT \mathcal{L}_{\text{OT}} LOT) 和 ( L CD \mathcal{L}_{\text{CD}} LCD),通过梯度下降更新模型参数 ( θ \theta θ)。
  3. 超参数

    • 采样时间 ( τ s \tau_s τs):控制总采样步数,实验表明在CIFAR-10上 ( τ s = 3.0 \tau_s = 3.0 τs=3.0) 时生成质量(FID)达到稳定。
    • ( τ ∗ \tau^* τ):控制温度从0过渡到 ( ε max \varepsilon_{\text{max}} εmax) 的时间点(如 ( τ ∗ = 0.9 \tau^* = 0.9 τ=0.9))。
    • ( λ CD \lambda_{\text{CD}} λCD):如CIFAR-10使用 ( λ CD = 2 × 1 0 − 4 \lambda_{\text{CD}} = 2 \times 10^{-4} λCD=2×104),CelebA使用 ( λ CD = 2 × 1 0 − 5 \lambda_{\text{CD}} = 2 \times 10^{-5} λCD=2×105)。

直观解释

双目标优化就像是“先搭框架,再精雕细琢”:

  • Phase 1 使用 ( L OT \mathcal{L}_{\text{OT}} LOT) 搭建一个粗略的能量景观,确保样本可以快速从噪声区域移动到数据流形附近,就像建造一条高速公路。
  • Phase 2 引入 ( L CD \mathcal{L}_{\text{CD}} LCD),在数据流形附近精雕细琢,形成精确的Boltzmann分布,就像在目标区域精心设计地形。
  • 温度调度 ( ε ( t ) \varepsilon(t) ε(t)) 起到“软切换”的作用,早期强调确定性流(OT),后期引入随机性(EBM),实现平滑过渡。

在框架中的作用

双目标策略是Energy Matching框架的核心创新之一:

  • 稳定性:通过分阶段训练(先 ( L OT \mathcal{L}_{\text{OT}} LOT),后联合优化),避免了传统EBM训练中的不稳定性和模式崩塌。
  • 高效性:( L OT \mathcal{L}_{\text{OT}} LOT) 提供了高质量的初始样本,减少了朗之万采样所需的步数。
  • 灵活性:温度调度和 ( λ CD \lambda_{\text{CD}} λCD) 允许根据数据集调整流和EBM的平衡,适应不同复杂度的生成任务。

例子:CIFAR-10生成

以CIFAR-10数据集为例,说明训练目标的应用:

  1. Phase 1

    • 采样128个数据样本和128个高斯噪声样本,计算OT映射 ( T T T )。
    • 优化 ( L OT \mathcal{L}_{\text{OT}} LOT),使 ( ∇ x V θ ( x t ( i ) ) \nabla_x V_\theta(x_t^{(i)}) xVθ(xt(i))) 匹配目标速度,训练200k次迭代。
    • 结果:形成平滑的传输路径,样本从噪声快速接近CIFAR-10图像的流形。
  2. Phase 2

    • 继续优化 ( L OT \mathcal{L}_{\text{OT}} LOT),同时引入 ( L CD \mathcal{L}_{\text{CD}} LCD)(( λ CD = 2 × 1 0 − 4 \lambda_{\text{CD}} = 2 \times 10^{-4} λCD=2×104))。
    • 使用朗之万动力学生成负样本(200步,初始化为50%数据+50%噪声)。
    • 温度从0逐渐增加到 ( ε max = 0.01 \varepsilon_{\text{max}} = 0.01 εmax=0.01),训练25k次迭代。
    • 结果:形成Boltzmann分布,FID达到3.97,显著优于传统EBM的8.61。

总结

Energy Matching的训练目标通过 ( L OT \mathcal{L}_{\text{OT}} LOT) 和 ( L CD \mathcal{L}_{\text{CD}} LCD) 的协同作用,实现了从噪声到数据的平滑过渡和精确的密度建模:

  • ( L OT \mathcal{L}_{\text{OT}} LOT) 提供了一个高效的确定性流,将噪声样本快速引导到数据流形,奠定了稳定的训练基础。
  • ( L CD \mathcal{L}_{\text{CD}} LCD) 精细调整能量函数,形成Boltzmann分布,捕捉数据的局部密度结构。
  • 双目标优化 通过温度调度和分阶段训练,平衡了OT和EBM的优势,确保了训练的稳定性和生成质量。

通过这些训练目标,Energy Matching不仅在CIFAR-10等数据集上取得了优异的生成性能(FID 3.97),还为逆问题求解和局部内在维度估计提供了灵活的框架。这一方法展示了静态、无旋生成模型的潜力,为未来的生成模型研究开辟了新的方向。

代码实现

为了复现《Energy Matching: Unifying Flow Matching and Energy-Based Models for Generative Modeling》一文中在CIFAR-10数据集上的实验,我们需要实现Energy Matching的训练代码,遵循论文中描述的模型结构和超参数。以下是一个完整的Python代码实现,使用PyTorch框架,基于论文中的描述(特别是2.1节和Appendix C)。代码包括模型定义、训练流程(Phase 1 和 Phase 2)、以及必要的工具函数。


代码概述

模型结构

根据论文Appendix C和Figure 5,模型结构为:

  • UNet:与[Tong et al., 2023]相同的UNet架构,参数量约37M,输入为3×32×32的CIFAR-10图像。
  • Vision Transformer (ViT):一个8层ViT头(包括PatchEmbed),参数量约12M,输出标量势函数 ( V θ ( x ) V_\theta(x) Vθ(x) )。
  • 总参数量约49M。

超参数

根据Appendix C(CIFAR-10部分):

  • 采样时间:( τ s = 3.0 \tau_s = 3.0 τs=3.0)
  • 温度切换点:( τ ∗ = 0.9 \tau^* = 0.9 τ=0.9)
  • 时间步长:( Δ t = 0.01 \Delta t = 0.01 Δt=0.01)
  • 朗之万采样步数:( M Langevin = 200 M_{\text{Langevin}} = 200 MLangevin=200)
  • 训练迭代:Phase 1为200k次,Phase 2为25k次
  • 批大小:128
  • 学习率:( 8 × 1 0 − 4 8 \times 10^{-4} 8×104)
  • 最大温度:( ε max = 0.01 \varepsilon_{\text{max}} = 0.01 εmax=0.01)
  • 对比散度权重:( λ CD = 2 × 1 0 − 4 \lambda_{\text{CD}} = 2 \times 10^{-4} λCD=2×104)

训练流程

  • Phase 1(Algorithm 1):仅优化 ( L OT \mathcal{L}_{\text{OT}} LOT),预训练模型以构建从噪声到数据的流。
  • Phase 2(Algorithm 2):联合优化 ( L OT + λ CD L CD \mathcal{L}_{\text{OT}} + \lambda_{\text{CD}} \mathcal{L}_{\text{CD}} LOT+λCDLCD),引入朗之万采样以形成Boltzmann分布。

依赖

  • PyTorch:用于模型定义和训练
  • POT(Python Optimal Transport):用于计算最优传输映射
  • torchvision:加载CIFAR-10数据集
  • einops:处理张量操作
  • timm:提供ViT实现

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import ot  # Python Optimal Transport (POT)
import numpy as np
from einops import rearrange
from timm import create_model
import torch.nn.functional as F

# 超参数(根据Appendix C)
BATCH_SIZE = 128
LEARNING_RATE = 8e-4
TAU_S = 3.0
TAU_STAR = 0.9
DELTA_T = 0.01
EPSILON_MAX = 0.01
LAMBDA_CD = 2e-4
PHASE1_ITERS = 200000
PHASE2_ITERS = 25000
LANGEVIN_STEPS = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# UNet模型(简化为示例,实际应使用[Tong et al., 2023]的UNet)
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=64):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1, stride=2),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, 3, padding=1),
        )
    
    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return dec

# Energy Matching模型(UNet + ViT)
class EnergyMatchingModel(nn.Module):
    def __init__(self):
        super(EnergyMatchingModel, self).__init__()
        self.unet = UNet(in_channels=3, out_channels=64)
        # ViT头(使用timm提供的ViT,简化为small模型)
        self.vit = create_model('vit_small_patch16_224', pretrained=False, num_classes=1)
        self.patch_embed = nn.Conv2d(64, 384, kernel_size=16, stride=16)  # 适配UNet输出
    
    def forward(self, x):
        unet_out = self.unet(x)  # [B, 64, 32, 32]
        vit_in = self.patch_embed(unet_out)  # [B, 384, 2, 2]
        vit_in = rearrange(vit_in, 'b c h w -> b (h w) c')  # [B, 4, 384]
        vit_out = self.vit.forward_features(vit_in)  # [B, 1]
        return vit_out.squeeze(-1)  # [B]

# 温度调度
def epsilon_schedule(t, tau_star=TAU_STAR, epsilon_max=EPSILON_MAX):
    if t < tau_star:
        return 0.0
    elif t <= 1.0:
        return epsilon_max * (t - tau_star) / (1.0 - tau_star)
    else:
        return epsilon_max

# OT损失
def compute_ot_loss(model, x_data, x_noise, t):
    B = x_data.size(0)
    # 计算OT映射
    cost_matrix = torch.cdist(x_data.view(B, -1), x_noise.view(B, -1)) ** 2
    a, b = torch.ones(B, device=DEVICE) / B, torch.ones(B, device=DEVICE) / B
    transport_plan = ot.emd(a, b, cost_matrix.detach().cpu().numpy())
    transport_plan = torch.tensor(transport_plan, device=DEVICE, dtype=torch.float32)
    
    # 找到配对
    indices = torch.argmax(transport_plan, dim=1)
    x_mapped = x_noise[indices]
    
    # 插值
    x_t = (1 - t) * x_mapped + t * x_data
    
    # 计算目标速度
    target_velocity = x_data - x_mapped
    
    # 计算模型梯度
    x_t = x_t.requires_grad_(True)
    v_theta = model(x_t)
    grad_v = torch.autograd.grad(v_theta.sum(), x_t, create_graph=True)[0]
    
    # OT损失
    loss = torch.mean((grad_v + target_velocity) ** 2)
    return loss

# 对比散度损失
def compute_cd_loss(model, x_data, x_noise, epsilon):
    # 正样本能量
    pos_energy = model(x_data).mean() / EPSILON_MAX
    
    # 负样本采样(朗之万动力学)
    x_neg = x_noise.clone().requires_grad_(False)
    for _ in range(LANGEVIN_STEPS):
        x_neg = x_neg.requires_grad_(True)
        v_theta = model(x_neg)
        grad_v = torch.autograd.grad(v_theta.sum(), x_neg, create_graph=True)[0]
        noise = torch.randn_like(x_neg) * (2 * DELTA_T * epsilon) ** 0.5
        x_neg = x_neg - DELTA_T * grad_v + noise
        x_neg = x_neg.detach()
    
    # 负样本能量
    neg_energy = model(x_neg).mean() / EPSILON_MAX
    
    # 对比散度损失
    loss = pos_energy - neg_energy
    return loss

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# 初始化模型和优化器
model = EnergyMatchingModel().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Phase 1: 预训练(仅优化L_OT)
print("Starting Phase 1 Training...")
model.train()
for iteration in range(PHASE1_ITERS):
    x_data, _ = next(iter(train_loader))
    x_data = x_data.to(DEVICE)
    x_noise = torch.randn_like(x_data).to(DEVICE)
    t = torch.rand(1, device=DEVICE).item()
    
    optimizer.zero_grad()
    loss_ot = compute_ot_loss(model, x_data, x_noise, t)
    loss_ot.backward()
    optimizer.step()
    
    if (iteration + 1) % 1000 == 0:
        print(f"Iteration {iteration + 1}/{PHASE1_ITERS}, L_OT: {loss_ot.item():.4f}")

# Phase 2: 主训练(联合优化L_OT + L_CD)
print("Starting Phase 2 Training...")
for iteration in range(PHASE2_ITERS):
    x_data, _ = next(iter(train_loader))
    x_data = x_data.to(DEVICE)
    x_noise = torch.randn_like(x_data).to(DEVICE)
    t = torch.rand(1, device=DEVICE).item()
    epsilon = epsilon_schedule(t)
    
    optimizer.zero_grad()
    loss_ot = compute_ot_loss(model, x_data, x_noise, t)
    loss_cd = compute_cd_loss(model, x_data, x_noise, epsilon)
    loss = loss_ot + LAMBDA_CD * loss_cd
    loss.backward()
    optimizer.step()
    
    if (iteration + 1) % 1000 == 0:
        print(f"Iteration {iteration + 1}/{PHASE2_ITERS}, L_OT: {loss_ot.item():.4f}, L_CD: {loss_cd.item():.4f}, Total Loss: {loss.item():.4f}")

# 保存模型
torch.save(model.state_dict(), "energy_matching_cifar10.pth")
print("Training completed and model saved.")

代码说明

1. 模型结构

  • UNet:为简化示例,使用了一个小型UNet(实际应用应使用[Tong et al., 2023]的完整UNet架构)。输入为3×32×32的CIFAR-10图像,输出为64通道特征图。
  • ViT:使用timm库的vit_small_patch16_224作为基础,调整输入为UNet的输出(64×32×32)。通过PatchEmbed将特征图转换为序列,ViT输出标量 ( V θ ( x ) V_\theta(x) Vθ(x) )。
  • 实际实现中,应确保UNet参数量约37M,ViT约12M,可通过调整层数或通道数实现。

2. 训练目标

  • ( L OT \mathcal{L}_{\text{OT}} LOT)
    • 使用POT库的ot.emd计算最优传输计划,生成配对。
    • 插值点 ( x t x_t xt ) 沿测地线计算,目标速度为 ( x data − T ( x data ) x_{\text{data}} - T(x_{\text{data}}) xdataT(xdata) )。
    • 损失通过均方误差优化模型梯度与目标速度的匹配。
  • ( L CD \mathcal{L}_{\text{CD}} LCD)
    • 使用朗之万动力学生成负样本,步数为200。
    • 负样本初始化为噪声(简化起见,未实现50%数据+50%噪声的混合初始化,可通过添加条件实现)。
    • 对比散度损失计算正负样本的能量差。

3. 训练流程

  • Phase 1:200k次迭代,仅优化 ( L OT \mathcal{L}_{\text{OT}} LOT),构建流路径。
  • Phase 2:25k次迭代,联合优化 ( L OT + λ CD L CD \mathcal{L}_{\text{OT}} + \lambda_{\text{CD}} \mathcal{L}_{\text{CD}} LOT+λCDLCD),温度 ( ε ( t ) \varepsilon(t) ε(t)) 按线性调度增加。
  • 每1000次迭代打印损失,便于监控训练进展。

4. 数据加载

  • 使用torchvision加载CIFAR-10数据集,应用标准归一化(均值0.5,标准差0.5)。
  • 批大小为128,使用4个工作线程加速数据加载。

5. 优化器

  • 使用Adam优化器,学习率为 ( 8 × 1 0 − 4 8 \times 10^{-4} 8×104 )。
  • 梯度通过PyTorch的autograd自动计算。

运行环境

  • 硬件:论文使用4×A100 GPU,示例代码可在单GPU(如RTX3090)或多GPU上运行。
  • 依赖安装
    pip install torch torchvision torchaudio
    pip install POT
    pip install einops timm
    
  • 数据集:CIFAR-10将自动下载到./data目录。

注意事项

  1. UNet简化:示例中的UNet为简化版,实际应参考[Tong et al., 2023]的完整实现(如https://github.com/alexandtong/OT-CFM)。可替换UNet类为完整架构。
  2. 负样本初始化:论文建议负样本50%从数据、50%从噪声初始化,示例中仅使用噪声初始化。可修改compute_cd_loss添加混合初始化。
  3. 计算资源:200k+25k次迭代需要数天训练时间,建议使用多GPU加速。
  4. FID评估:代码未包含FID计算,可使用torch-fidelity库评估生成质量(目标FID约3.97)。
  5. OT求解器:POT库的ot.emd适用于小批量数据,大规模实验可能需优化OT计算效率。

扩展

  • 生成样本:训练完成后,可添加采样代码,使用朗之万动力学从 ( ρ eq ( x ) ∝ exp ⁡ ( − V θ ( x ) ε max ) \rho_{\text{eq}}(x) \propto \exp\left(-\frac{V_\theta(x)}{\varepsilon_{\text{max}}}\right) ρeq(x)exp(εmaxVθ(x))) 生成样本。
  • 逆问题:参考Algorithm 3,可扩展代码支持带交互能量的逆问题求解。
  • LID估计:参考3.3节,可添加Hessian计算代码估计局部内在维度。

通过运行上述代码,可以在CIFAR-10上复现Energy Matching的训练过程。

后记

2025年4月17日于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值