Rectified Flow(一):从直线路径到高效分布变换的深度学习新视角

Rectified Flow:从直线路径到高效分布变换的深度学习新视角

引言

在深度学习领域,生成模型和分布变换任务(如图像生成、图像到图像翻译、域适应等)一直是研究的热点。传统的生成模型(如GAN和VAE)以及近年流行的扩散模型(Diffusion Models)在这些任务中取得了显著成果,但也面临着各自的挑战:GAN训练不稳定,扩散模型推理速度慢。2022年,来自德克萨斯大学奥斯汀分校的研究团队(Xingchao Liu, Chengyue Gong, Qiang Liu)提出了一种新颖的方法——Rectified Flow,通过普通微分方程(ODE)以直线路径为核心思想,统一解决了生成建模和域迁移问题。本文将深入探讨Rectified Flow的核心机制、数学公式及其在深度学习中的洞见。

下文中图片来自于原论文:https://arxiv.org/pdf/2209.03003


Rectified Flow的核心思想

Rectified Flow的核心在于通过学习一个神经ODE模型,将两个经验分布 π 0 \pi_0 π0 π 1 \pi_1 π1 之间的变换路径尽可能调整为直线路径。为什么选择直线路径?直线路径是两点间的最短路径,具有理论上的优越性(最优运输的几何直观性)和计算上的优势(可以精确模拟,不需过多时间离散化)。这使得Rectified Flow在推理阶段能够以极少的步骤(甚至一步)生成高质量结果,弥补了传统连续时间模型(如扩散模型)推理成本高的缺陷。

具体而言,Rectified Flow的目标是从 π 0 \pi_0 π0 π 1 \pi_1 π1 构建一个运输映射 T T T,使得 Z 0 ∼ π 0 Z_0 \sim \pi_0 Z0π0 时, Z 1 = T ( Z 0 ) ∼ π 1 Z_1 = T(Z_0) \sim \pi_1 Z1=T(Z0)π1。它通过以下步骤实现:

  1. 初始耦合:从 π 0 \pi_0 π0 π 1 \pi_1 π1 中抽样一对 ( X 0 , X 1 ) (X_0, X_1) (X0,X1),通常是独立的(即 ( X 0 , X 1 ) ∼ π 0 × π 1 (X_0, X_1) \sim \pi_0 \times \pi_1 (X0,X1)π0×π1)。
  2. 直线插值:定义线性插值路径 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1-t)X_0 Xt=tX1+(1t)X0,其中 t ∈ [ 0 , 1 ] t \in [0,1] t[0,1]
  3. 学习漂移场:训练一个神经网络参数化的漂移函数 v ( z , t ) v(z,t) v(z,t),使ODE d Z t = v ( Z t , t ) d t \mathrm{d}Z_t = v(Z_t, t)\mathrm{d}t dZt=v(Zt,t)dt 的路径尽可能接近直线方向 X 1 − X 0 X_1 - X_0 X1X0
  4. 迭代校正(Reflow):通过递归应用校正过程,使路径越来越直,最终接近一步模型。

这种方法不仅适用于生成建模( π 0 \pi_0 π0 为噪声分布, π 1 \pi_1 π1 为数据分布),也适用于域迁移( π 0 \pi_0 π0 π 1 \pi_1 π1 为两个经验分布)。


数学公式与算法细节
基本定义

Rectified Flow的核心是一个ODE模型:
d Z t = v ( Z t , t ) d t , Z 0 ∼ π 0 , Z 1 ∼ π 1 \mathrm{d}Z_t = v(Z_t, t) \mathrm{d}t, \quad Z_0 \sim \pi_0, \quad Z_1 \sim \pi_1 dZt=v(Zt,t)dt,Z0π0,Z1π1
其中 v ( z , t ) v(z,t) v(z,t) 是漂移函数,表示在时间 t t t 和位置 z z z 处的速度场。目标是让 Z t Z_t Zt 的轨迹尽量接近 X 1 − X 0 X_1 - X_0 X1X0 的方向。

优化目标

为了学习 v ( z , t ) v(z,t) v(z,t),Rectified Flow通过一个简单的非线性最小二乘问题定义优化目标:
min ⁡ v ∫ 0 1 E [ ∥ ( X 1 − X 0 ) − v ( X t , t ) ∥ 2 2 ] d t , X t = t X 1 + ( 1 − t ) X 0 \min_v \int_0^1 \mathbb{E}\left[ \left\| (X_1 - X_0) - v(X_t, t) \right\|_2^2 \right] \mathrm{d}t, \quad X_t = t X_1 + (1-t) X_0 vmin01E[(X1X0)v(Xt,t)22]dt,Xt=tX1+(1t)X0
这里:

  • X t X_t Xt ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 的线性插值,表示理想的直线路径。
  • ( X 1 − X 0 ) (X_1 - X_0) (X1X0) 是直线方向的速度。
  • v ( X t , t ) v(X_t, t) v(Xt,t) 是模型预测的速度,目标是让它尽量接近 ( X 1 − X 0 ) (X_1 - X_0) (X1X0)

优化过程使用随机梯度下降(SGD)或Adam等方法,基于 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 的经验样本训练神经网络 v θ ( z , t ) v_\theta(z,t) vθ(z,t)

理论最优解

如果优化问题能精确求解,最优的漂移场为:
v X ( z , t ) = E [ X 1 − X 0 ∣ X t = z ] v^X(z, t) = \mathbb{E}\left[ X_1 - X_0 \mid X_t = z \right] vX(z,t)=E[X1X0Xt=z]
这表示在时间 t t t 和位置 z z z 处,所有经过 z z z 的直线方向 ( X 1 − X 0 ) (X_1 - X_0) (X1X0) 的条件期望。直观上, v X ( z , t ) v^X(z,t) vX(z,t) 将线性插值路径“因果化”,避免了交叉(见下文非交叉性质)。

非交叉性质

ODE的解具有唯一性,因此 Z t Z_t Zt 的轨迹不会交叉。这与 X t X_t Xt 的线性插值路径(可能交叉)形成对比。Rectified Flow通过“重新布线”(rewiring)避免交叉,同时保持边际分布一致:
Law ( Z t ) = Law ( X t ) , ∀ t ∈ [ 0 , 1 ] \text{Law}(Z_t) = \text{Law}(X_t), \quad \forall t \in [0,1] Law(Zt)=Law(Xt),t[0,1]

运输成本下降

Rectified Flow的一个重要理论性质是,它将任意耦合 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 转换为确定性耦合 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1),且对于所有凸成本函数 c c c,运输成本不增加:
E [ c ( Z 1 − Z 0 ) ] ≤ E [ c ( X 1 − X 0 ) ] \mathbb{E}\left[ c(Z_1 - Z_0) \right] \leq \mathbb{E}\left[ c(X_1 - X_0) \right] E[c(Z1Z0)]E[c(X1X0)]
例如,当 c ( x ) = ∥ x ∥ c(x) = \|x\| c(x)=x 时, Z t Z_t Zt 的路径长度不会超过 X t X_t Xt 的直线长度(三角不等式)。

Reflow与路径变直

为了进一步接近直线路径,论文提出了递归校正(Reflow)过程:

  1. 从初始耦合 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 训练第一个Rectified Flow Z 1 Z^1 Z1
  2. 模拟 Z 1 Z^1 Z1 得到新耦合 ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11)
  3. ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11) 训练 Z 2 Z^2 Z2,重复此过程。

路径直线性用以下指标衡量:
S ( Z ) = ∫ 0 1 E [ ∥ ( Z 1 − Z 0 ) − Z ˙ t ∥ 2 ] d t S(Z) = \int_0^1 \mathbb{E}\left[ \left\| (Z_1 - Z_0) - \dot{Z}_t \right\|^2 \right] \mathrm{d}t S(Z)=01E[ (Z1Z0)Z˙t 2]dt
S ( Z ) = 0 S(Z) = 0 S(Z)=0 表示完全直线。理论上, S ( Z k ) S(Z^k) S(Zk) k k k 增加而减小:
min ⁡ k ∈ { 0 , … , K } S ( Z k ) ≤ E [ ∥ X 1 − X 0 ∥ 2 ] K \min_{k \in \{0, \dots, K\}} S(Z^k) \leq \frac{\mathbb{E}\left[ \|X_1 - X_0\|^2 \right]}{K} k{0,,K}minS(Zk)KE[X1X02]
实验表明,一次Reflow( k = 2 k=2 k=2)即可显著提高直线性。

算法伪代码

以下是训练和采样的核心算法(参考论文Algorithm 1):

# 训练
def Train(Data):
    # Data = {(x0, x1)} 从 π_0 × π_1 采样
    initialize v_θ(z, t)
    for (x0, x1) in Data:
        t ~ Uniform([0, 1])
        x_t = t * x1 + (1 - t) * x0
        loss = ||v_θ(x_t, t) - (x1 - x0)||^2.mean()
        optimize(loss)
    return v_θ

# 采样
def Sample(v_θ, x0):
    # x0 ~ π_0
    z_t = ODE_solve(v_θ, z_0 = x0, t ∈ [0, 1])
    return (z_0, z_1)

# Reflow
def Reflow(Data, K):
    coupling = Data
    for k in range(K):
        v_θ = Train(coupling)
        coupling = Sample(v_θ, {x0 from coupling})
    return coupling

实验结果与应用
  1. 图像生成

    • 在CIFAR-10上,2-Rectified Flow一步生成(Euler步长 N = 1 N=1 N=1)的FID达到4.85,优于传统一步模型(如GAN的8.91)。
    • 在高分辨率数据集(如LSUN、CelebA HQ)上,1-Rectified Flow也能生成高质量图像。
  2. 图像到图像翻译

    • 在无配对数据的情况下(如人脸到猫脸),通过调整损失函数(考虑特征映射 h ( x ) h(x) h(x)),Rectified Flow生成视觉上高质量的混合图像。
    • 2-Rectified Flow一步即可完成风格迁移。
  3. 域适应

    • 在DomainNet和OfficeHome数据集上,Rectified Flow将测试域迁移到训练域,分类准确率达到state-of-the-art(69.2% 和 41.4%)。

与其他方法的对比
  1. 与GAN的对比

    • GAN通过对抗训练学习映射,易出现模式崩塌和不稳定。Rectified Flow用简单的回归优化,避免了这些问题。
  2. 与扩散模型的对比

    • 扩散模型(如DDPM)依赖SDE,推理需要数百步。Rectified Flow是纯ODE方法,通过直线路径大幅减少推理步骤。
    • 论文证明,概率流ODE(PF-ODE)和DDIM是Rectified Flow的非线性特例,但路径非直且速度不均。

洞见与启发
  1. 直线路径的哲学意义

    • Rectified Flow揭示了分布变换的最优性可能不依赖复杂的曲线路径。直线路径不仅是计算上的捷径,也可能是数据分布间联系的本质体现。这提示我们在设计生成模型时,应优先考虑几何简单性。
  2. 因果性与确定性

    • 通过将非因果的线性插值“因果化”,Rectified Flow将随机耦合转化为确定性耦合。这为理解生成过程的因果结构提供了新视角。
  3. Reflow的潜力

    • Reflow过程类似于“自蒸馏”,通过迭代优化逼近一步模型。这种策略可能适用于其他连续时间模型(如扩散模型),以提升推理效率。
  4. 扩展性思考

    • 论文提到非线性扩展(用任意曲线替换线性插值),这为引入非欧几何(如流形上的变换)提供了可能。未来研究可探索如何结合领域知识设计 X t X_t Xt

结论

Rectified Flow以其简洁而优雅的设计,为生成建模和分布变换任务提供了一个统一的、高效的解决方案。它通过直线路径和递归校正,不仅在理论上保证了边际分布一致性和运输成本下降,还在实践中实现了快速推理和高性能。对于深度学习研究者而言,这不仅是一种新工具,更是一种新思路:简单性与效率的结合,往往能带来意想不到的突破。


分布沿直线或曲线变化的意义和边际分布(Marginal Distribution)的含义。

由于篇幅限制,具体的内容可以参考笔者的另一篇博客:Rectified Flow(二):边际分布(Marginal Distribution)的作用, 主要内容有:分布沿直线或曲线变化的意义和边际分布(Marginal Distribution)的含义。

理解Figure 2

Figure 2 是 Rectified Flow 方法的核心可视化示例,旨在展示如何从初始的线性插值路径(可能有交叉)逐步校正为更直、更符合 ODE 性质的路径。让我们一步步拆解。


在这里插入图片描述

Figure 2 的背景和目标

Rectified Flow 的目标是通过普通微分方程(ODE)从分布 π 0 \pi_0 π0 变换到分布 π 1 \pi_1 π1,并且希望变换路径尽量接近直线路径。直线路径有两大优势:

  1. 最短路径:直线是两点间的最短路径,理论上运输成本最低。
  2. 计算效率:直线路径可以用一步模拟(无需多步离散化)。

然而,初始的线性插值路径 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0(从 X 0 ∼ π 0 X_0 \sim \pi_0 X0π0 X 1 ∼ π 1 X_1 \sim \pi_1 X1π1)可能会出现轨迹交叉(crossing),这不符合 ODE 的非交叉性质(因为 ODE 的解是唯一的,轨迹不能交叉)。Figure 2 通过四个子图展示了 Rectified Flow 如何通过“校正”(rectification)和“重新布线”(rewiring)解决这个问题,并逐步将路径变直。


Figure 2 的四个子图

(a) Linear interpolation X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0
  • 内容:图中展示了两个分布 π 0 \pi_0 π0(紫色点)和 π 1 \pi_1 π1(红色点),每个点对 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 之间用一条线连接,表示线性插值路径 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0
  • 关键点:这些路径会交叉。例如,图中两条路径(绿色和蓝色)在中间相交,形成了 X 0 → X 1 X_0 \rightarrow X_1 X0X1 X 0 ′ → X 1 ′ X_0' \rightarrow X_1' X0X1 的交叉。
  • 问题:这种交叉是非因果的(non-causal),因为线性插值需要同时知道起点 X 0 X_0 X0 和终点 X 1 X_1 X1,而 ODE 要求路径是因果的(只依赖当前状态 Z t Z_t Zt 和时间 t t t)。

直观解释
想象 π 0 \pi_0 π0 π 1 \pi_1 π1 是两个城市群,每对 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 是一条直线公路。但这些公路可能会交叉(比如在某个点 X t X_t Xt 上,绿色和蓝色路径重合但方向不同)。这会导致“交通混乱”,因为 ODE 要求每条路径在任何时刻都不能交叉,否则解就不唯一。


(b) Rectified flow Z t Z_t Zt induced by ( X 0 , X 1 ) (X_0, X_1) (X0,X1)
  • 内容:这一步展示了 Rectified Flow 对 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 的第一次校正,生成新的路径 Z t Z_t Zt,由 ODE d Z t = v ( Z t , t ) d t \mathrm{d}Z_t = v(Z_t, t) \mathrm{d}t dZt=v(Zt,t)dt 定义。
  • 关键点 Z t Z_t Zt 的轨迹在交叉点被“重新布线”(rewired)。图中可以看到,绿色和蓝色路径在交叉点 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 被重新配对,避免了交叉。原来的 X 0 → X 1 X_0 \rightarrow X_1 X0X1 X 0 ′ → X 1 ′ X_0' \rightarrow X_1' X0X1 变成了 Z 0 → Z 1 Z_0 \rightarrow Z_1 Z0Z1 Z 0 ′ → Z 1 ′ Z_0' \rightarrow Z_1' Z0Z1
  • 非交叉性质:ODE 的解是唯一的,因此 Z t Z_t Zt 的轨迹不能交叉。论文中提到:

    A key to understanding the method is the non-crossing property of flows: the different paths following a well-defined ODE d Z t = v ( Z t , t ) d t \mathrm{d}Z_t = v(Z_t, t) \mathrm{d}t dZt=v(Zt,t)dt, whose solution exists and is unique, cannot cross each other at any time t ∈ [ 0 , 1 ) t \in [0,1) t[0,1).

直观解释
继续用公路的比喻,Rectified Flow 在交叉点设置了“交通规则”,将原本交叉的公路重新规划为不交叉的路径(比如通过立交桥)。这使得每条路径 Z t Z_t Zt 都能独立前进,不会与其他路径冲突。

技术细节

  • v ( z , t ) v(z, t) v(z,t) 是通过优化以下目标学习的:
    min ⁡ v ∫ 0 1 E [ ∥ ( X 1 − X 0 ) − v ( X t , t ) ∥ 2 2 ] d t \min_v \int_0^1 \mathbb{E}\left[ \| (X_1 - X_0) - v(X_t, t) \|_2^2 \right] \mathrm{d}t vmin01E[(X1X0)v(Xt,t)22]dt
  • 最优解为 v X ( z , t ) = E [ X 1 − X 0 ∣ X t = z ] v^X(z, t) = \mathbb{E}[X_1 - X_0 \mid X_t = z] vX(z,t)=E[X1X0Xt=z],表示在 X t = z X_t = z Xt=z 时,所有直线方向的条件期望。

( c) Linear interpolation Z t = t Z 1 + ( 1 − t ) Z 0 Z_t = t Z_1 + (1 - t) Z_0 Zt=tZ1+(1t)Z0
  • 内容:这一步用前一步生成的 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 作为新的起点和终点,再次进行线性插值,生成路径 Z t = t Z 1 + ( 1 − t ) Z 0 Z_t = t Z_1 + (1 - t) Z_0 Zt=tZ1+(1t)Z0
  • 关键点:由于 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 已经是校正后的配对,新的线性插值路径仍然可能有交叉,但相比于 ( X 0 , X 1 ) (X_0, X_1) (X0,X1),交叉的程度已经减少。

直观解释
这一步相当于用新的城市对 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 重新规划公路。新的公路仍然是直线,但因为 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 已经经过一次校正,交叉问题有所缓解。


(d) Rectified flow Z t ′ Z_t' Zt induced by ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1)
  • 内容:对 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 再次应用 Rectified Flow,生成新的路径 Z t ′ Z_t' Zt
  • 关键点:图中显示, Z t ′ Z_t' Zt 的轨迹已经非常接近直线,几乎没有交叉。路径变得更“直”(straight),这是 Reflow 过程的结果。
  • 意义:通过递归校正(Reflow),路径越来越直,最终接近一步模型(single-step model),即 Z 1 = Z 0 + v ( Z 0 , 0 ) Z_1 = Z_0 + v(Z_0, 0) Z1=Z0+v(Z0,0)

直观解释
经过第二次规划,公路几乎完全没有交叉,变成了近似直线的路径。这意味着从 Z 0 Z_0 Z0 Z 1 Z_1 Z1 的变换可以用一步完成,极大提高了计算效率。

技术细节

  • 直线性用以下指标衡量:
    S ( Z ) = ∫ 0 1 E [ ∥ ( Z 1 − Z 0 ) − Z ˙ t ∥ 2 2 ] d t S(\boldsymbol{Z}) = \int_0^1 \mathbb{E}\left[ \| (Z_1 - Z_0) - \dot{Z}_t \|_2^2 \right] \mathrm{d}t S(Z)=01E[(Z1Z0)Z˙t22]dt
  • S ( Z ) = 0 S(\boldsymbol{Z}) = 0 S(Z)=0 表示完全直线。Reflow 过程会逐步减小 S ( Z ) S(\boldsymbol{Z}) S(Z)

Figure 2 的整体含义

从 (a) 到 (d) 的演化
  • (a):初始线性插值 X t X_t Xt,路径有交叉,非因果。
  • (b):第一次校正,生成 Z t Z_t Zt,通过重新布线避免交叉。
  • ( c):用校正后的 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 再次线性插值,交叉减少。
  • (d):第二次校正,生成 Z t ′ Z_t' Zt,路径几乎完全直线。
核心思想

Figure 2 展示了 Rectified Flow 的核心机制:

  1. 非交叉性:ODE 路径不能交叉,Rectified Flow 通过“重新布线”解决线性插值的交叉问题。
  2. 路径变直:通过递归校正(Reflow),路径从可能交叉的复杂状态逐步变为直线状态。
  3. 效率提升:直线路径可以一步模拟,减少推理成本。
论文中的解释

论文在 Figure 2 的标题中提到:

The trajectories are “rewired” at the intersection points to avoid the crossing. (…) The rectified flow induced from ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1), which follows straight paths.

翻译为中文:

轨迹在交叉点被“重新布线”以避免交叉。(…) 由 ( Z 0 , Z 1 ) (Z_0, Z_1) (Z0,Z1) 诱导的校正流,遵循直线路径。

这强调了 Rectified Flow 的两个关键性质:避免交叉和路径变直。


结合上下文的意义

为什么要避免交叉?

在概率分布变换中,交叉会导致路径的非确定性(non-determinism)。例如,假设 X t X_t Xt 在某个点 z z z 处交叉,绿色路径和蓝色路径在 t = 0.5 t = 0.5 t=0.5 时都经过 z z z,但方向不同(比如一个向右,一个向左)。这会导致 ODE 的解不唯一(因为 v ( z , t ) v(z, t) v(z,t) 无法确定方向),违反了 ODE 的基本性质。Rectified Flow 通过 v ( z , t ) v(z, t) v(z,t) 的设计(条件期望)确保路径不交叉。

为什么追求直线?
  • 理论上:直线是最短路径,运输成本最低(论文中提到降低了凸运输成本)。
  • 实践上:直线路径可以用一步 Euler 步模拟( Z 1 = Z 0 + v ( Z 0 , 0 ) Z_1 = Z_0 + v(Z_0, 0) Z1=Z0+v(Z0,0)),相比于扩散模型(需要数百步)效率更高。
实际应用
  • 在图像生成中,直线路径意味着从噪声到图像的变换更直接,中间状态更可控。
  • 在域迁移中,直线路径使得从源域到目标域的风格转换更平滑。

总结

Figure 2 是一个简化的二维玩具示例,展示了 Rectified Flow 如何从初始的交叉路径(线性插值)逐步校正为不交叉、近似直线的路径:

  • (a) 展示了问题:线性插值会导致交叉。
  • (b) 展示了第一次校正:通过重新布线避免交叉。
  • (c ) 展示了中间步骤:用校正后的点对重新插值。
  • (d) 展示了最终结果:路径几乎完全直线。

这个过程体现了 Rectified Flow 的核心思想:通过递归校正,将复杂的变换路径简化为直线路径,从而实现高效的分布变换。

Reflow的过程

详细介绍 Rectified Flow 论文中提出的 Reflow 过程,包括它的目标、具体步骤、数学原理、算法实现,以及在实际应用中的作用。Reflow 是 Rectified Flow 的核心技术之一,通过递归校正(recursive rectification)将初始的变换路径逐步调整为更直的路径,从而实现高效的一步生成或分布变换。


Reflow 的目标

Rectified Flow 的目标是从分布 π 0 \pi_0 π0 变换到分布 π 1 \pi_1 π1,并希望变换路径尽量接近直线路径。直线路径的优势在于:

  1. 最优运输:直线是两点间的最短路径,理论上运输成本最低。
  2. 高效推理:直线路径可以用一步模拟(single-step sampling),相比于传统连续时间模型(如扩散模型需要数百步)更高效。

然而,初始的 Rectified Flow(称为 1-Rectified Flow)生成的路径 Z t Z_t Zt 可能不是完全直线,因为:

  • 初始耦合 ( X 0 , X 1 ) ∼ π 0 × π 1 (X_0, X_1) \sim \pi_0 \times \pi_1 (X0,X1)π0×π1 是独立采样的,路径可能交叉。
  • 神经网络 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 是对最优漂移场 v X ( z , t ) v^X(z, t) vX(z,t) 的近似,可能无法完全捕捉直线方向 X 1 − X 0 X_1 - X_0 X1X0

Reflow 的目标是通过递归应用 Rectified Flow,将路径 Z t Z_t Zt 逐步“变直”,最终接近一步模型(即 Z 1 = Z 0 + v ( Z 0 , 0 ) Z_1 = Z_0 + v(Z_0, 0) Z1=Z0+v(Z0,0)),同时保持边际分布 Law ( Z t ) \text{Law}(Z_t) Law(Zt) 不变。


Reflow 的核心思想

Reflow 的核心思想是迭代校正

  1. 从初始耦合 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 训练第一个 Rectified Flow,生成路径 Z t 1 Z^1_t Zt1 和新耦合 ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11)
  2. ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11) 作为新的训练数据,训练第二个 Rectified Flow,生成 Z t 2 Z^2_t Zt2 ( Z 0 2 , Z 1 2 ) (Z_0^2, Z_1^2) (Z02,Z12)
  3. 重复此过程 K K K 次,直到路径足够直。

每次校正(Reflow 一步)都会使路径更接近直线,减少路径的“弯曲度”(quantified by the straightness metric S ( Z ) S(\boldsymbol{Z}) S(Z))。

直线性指标

论文中定义了路径的直线性指标:
S ( Z ) = ∫ 0 1 E [ ∥ ( Z 1 − Z 0 ) − Z ˙ t ∥ 2 2 ] d t S(\boldsymbol{Z}) = \int_0^1 \mathbb{E}\left[ \| (Z_1 - Z_0) - \dot{Z}_t \|_2^2 \right] \mathrm{d}t S(Z)=01E[(Z1Z0)Z˙t22]dt

  • ( Z 1 − Z 0 ) (Z_1 - Z_0) (Z1Z0) 是理想的直线速度方向。
  • Z ˙ t = v ( Z t , t ) \dot{Z}_t = v(Z_t, t) Z˙t=v(Zt,t) 是实际的速度。
  • S ( Z ) = 0 S(\boldsymbol{Z}) = 0 S(Z)=0 表示路径完全直线。

Reflow 的目标是逐步减小 S ( Z ) S(\boldsymbol{Z}) S(Z),使路径越来越直。


Reflow 的具体步骤

1. 初始耦合

Reflow 从初始的样本对 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 开始,其中:

  • X 0 ∼ π 0 X_0 \sim \pi_0 X0π0 X 1 ∼ π 1 X_1 \sim \pi_1 X1π1
  • 初始耦合 ( X 0 , X 1 ) ∼ π 0 × π 1 (X_0, X_1) \sim \pi_0 \times \pi_1 (X0,X1)π0×π1,即 X 0 X_0 X0 X 1 X_1 X1 是独立采样的。
2. 第一次 Rectified Flow(1-Rectified Flow)
  • 训练:用 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 训练第一个漂移场 v θ 1 ( z , t ) v_\theta^1(z, t) vθ1(z,t),优化目标为:
    min ⁡ v θ 1 ∫ 0 1 E [ ∥ ( X 1 − X 0 ) − v θ 1 ( X t , t ) ∥ 2 2 ] d t , X t = t X 1 + ( 1 − t ) X 0 \min_{v_\theta^1} \int_0^1 \mathbb{E}\left[ \| (X_1 - X_0) - v_\theta^1(X_t, t) \|_2^2 \right] \mathrm{d}t, \quad X_t = t X_1 + (1 - t) X_0 vθ1min01E[(X1X0)vθ1(Xt,t)22]dt,Xt=tX1+(1t)X0
  • 采样:用训练好的 v θ 1 ( z , t ) v_\theta^1(z, t) vθ1(z,t) 解决 ODE:
    d Z t 1 = v θ 1 ( Z t 1 , t ) d t , Z 0 1 ∼ π 0 \mathrm{d}Z_t^1 = v_\theta^1(Z_t^1, t) \mathrm{d}t, \quad Z_0^1 \sim \pi_0 dZt1=vθ1(Zt1,t)dt,Z01π0
    得到 Z t 1 Z_t^1 Zt1 的轨迹,并生成新耦合 ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11),其中 Z 0 1 ∼ π 0 Z_0^1 \sim \pi_0 Z01π0 Z 1 1 ∼ π 1 Z_1^1 \sim \pi_1 Z11π1
3. 第二次 Rectified Flow(2-Rectified Flow)
  • 训练:用 ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11) 作为新的训练数据,训练第二个漂移场 v θ 2 ( z , t ) v_\theta^2(z, t) vθ2(z,t),优化目标为:
    min ⁡ v θ 2 ∫ 0 1 E [ ∥ ( Z 1 1 − Z 0 1 ) − v θ 2 ( Z t 1 , t ) ∥ 2 2 ] d t , Z t 1 = t Z 1 1 + ( 1 − t ) Z 0 1 \min_{v_\theta^2} \int_0^1 \mathbb{E}\left[ \| (Z_1^1 - Z_0^1) - v_\theta^2(Z_t^1, t) \|_2^2 \right] \mathrm{d}t, \quad Z_t^1 = t Z_1^1 + (1 - t) Z_0^1 vθ2min01E[(Z11Z01)vθ2(Zt1,t)22]dt,Zt1=tZ11+(1t)Z01
  • 采样:用 v θ 2 ( z , t ) v_\theta^2(z, t) vθ2(z,t) 解决 ODE:
    d Z t 2 = v θ 2 ( Z t 2 , t ) d t , Z 0 2 ∼ π 0 \mathrm{d}Z_t^2 = v_\theta^2(Z_t^2, t) \mathrm{d}t, \quad Z_0^2 \sim \pi_0 dZt2=vθ2(Zt2,t)dt,Z02π0
    得到 ( Z 0 2 , Z 1 2 ) (Z_0^2, Z_1^2) (Z02,Z12)
4. 重复 K K K
  • 重复上述过程 K K K 次,每次用前一步的 ( Z 0 k , Z 1 k ) (Z_0^k, Z_1^k) (Z0k,Z1k) 训练 v θ k + 1 ( z , t ) v_\theta^{k+1}(z, t) vθk+1(z,t),生成 ( Z 0 k + 1 , Z 1 k + 1 ) (Z_0^{k+1}, Z_1^{k+1}) (Z0k+1,Z1k+1)
  • 理论上, S ( Z k ) S(\boldsymbol{Z}^k) S(Zk) k k k 增加而减小,路径越来越直。
5. 最终结果
  • 经过 K K K 次 Reflow,路径接近完全直线,变换可以用一步完成:
    Z 1 K ≈ Z 0 K + v θ K ( Z 0 K , 0 ) Z_1^K \approx Z_0^K + v_\theta^K(Z_0^K, 0) Z1KZ0K+vθK(Z0K,0)

数学原理:为什么 Reflow 能让路径变直?

理论保证

论文中的 Theorem 3.5 提供了 Reflow 的理论支持:

Let Z k \boldsymbol{Z}^k Zk be the k k k-th rectified flow induced by X 0 = X \boldsymbol{X}^0 = \boldsymbol{X} X0=X. Then for any K ≥ 1 K \geq 1 K1, we have:
min ⁡ k ∈ { 0 , … , K } S ( Z k ) ≤ E [ ∥ X 1 − X 0 ∥ 2 ] K \min_{k \in \{0, \dots, K\}} S(\boldsymbol{Z}^k) \leq \frac{\mathbb{E}\left[ \|X_1 - X_0\|^2 \right]}{K} k{0,,K}minS(Zk)KE[X1X02]

  • 含义 S ( Z k ) S(\boldsymbol{Z}^k) S(Zk)(直线性指标)会随着 Reflow 次数 K K K 增加而减小,下降速度接近 1 / K 1/K 1/K
  • 直观解释:每次 Reflow 都让路径更接近直线,最终 S ( Z k ) → 0 S(\boldsymbol{Z}^k) \to 0 S(Zk)0,即路径完全直。
非交叉性质

每次 Reflow 生成的 Z t k Z_t^k Ztk 由 ODE 定义,轨迹不会交叉(因为 ODE 的解唯一)。这使得 ( Z 0 k , Z 1 k ) (Z_0^k, Z_1^k) (Z0k,Z1k) 的耦合比前一步更“确定性”(deterministic),路径更接近直线。

边际分布保持

论文证明了 Law ( Z t k ) = Law ( X t ) \text{Law}(Z_t^k) = \text{Law}(X_t) Law(Ztk)=Law(Xt)(定理3.3),即每次 Reflow 不会改变边际分布,只是调整了路径的几何形状。


算法实现

论文中的 Algorithm 1 提供了 Reflow 的伪代码,将其整理为更详细的实现:

# 训练 Rectified Flow
def train_rectified_flow(coupling, num_epochs):
    # coupling: 样本对 {(x0, x1)}
    v_theta = initialize_neural_network()  # 初始化漂移场 v_θ(z, t)
    for epoch in range(num_epochs):
        for (x0, x1) in coupling:
            t = random.uniform(0, 1)  # 随机采样 t
            x_t = t * x1 + (1 - t) * x0  # 线性插值
            target_velocity = x1 - x0  # 目标速度
            predicted_velocity = v_theta(x_t, t)  # 预测速度
            loss = ((predicted_velocity - target_velocity) ** 2).mean()  # 均方误差
            optimize(loss)  # 用 SGD 或 Adam 优化
    return v_theta

# 用 Rectified Flow 采样
def sample_rectified_flow(v_theta, x0, num_steps=100):
    # x0: 初始样本
    z_t = x0
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        z_t = z_t + v_theta(z_t, t) * dt  # Euler 积分
    return z_t

# Reflow 过程
def reflow(initial_coupling, K, num_epochs_per_flow):
    # initial_coupling: 初始样本对 {(x0, x1)} ~ π_0 × π_1
    coupling = initial_coupling
    for k in range(K):
        # 训练第 k 次 Rectified Flow
        v_theta = train_rectified_flow(coupling, num_epochs_per_flow)
        # 用训练好的 v_theta 采样新耦合
        new_coupling = []
        for (x0, _) in coupling:
            z1 = sample_rectified_flow(v_theta, x0)  # 从 x0 采样 z1
            new_coupling.append((x0, z1))
        coupling = new_coupling  # 更新耦合
    return coupling, v_theta

# 示例使用
initial_coupling = [(x0, x1) for x0, x1 in zip(pi_0_samples, pi_1_samples)]  # 初始样本对
K = 2  # Reflow 次数
num_epochs_per_flow = 1000  # 每次训练的 epoch 数
final_coupling, final_v_theta = reflow(initial_coupling, K, num_epochs_per_flow)
代码说明
  • train_rectified_flow:训练一个 Rectified Flow,通过最小化预测速度与目标速度的均方误差。
  • sample_rectified_flow:用训练好的 v θ v_\theta vθ 采样新路径,生成 ( Z 0 k , Z 1 k ) (Z_0^k, Z_1^k) (Z0k,Z1k)
  • reflow:递归调用训练和采样,逐步校正路径。

Reflow 的实际效果

实验结果

论文中的实验表明:

  1. 图像生成

    • 在 CIFAR-10 上,1-Rectified Flow(无 Reflow)的 FID 为 9.23,2-Rectified Flow(一次 Reflow)降低到 4.85。
    • 一步生成(single-step)即可达到高质量结果,相比于扩散模型(需要数百步)效率更高。
  2. 域迁移

    • 在 DomainNet 数据集上,2-Rectified Flow 提升了分类准确率(从 67.8% 到 69.2%)。
    • 路径变直使得从源域到目标域的变换更平滑。
路径直线性
  • 初始路径 S ( X ) S(\boldsymbol{X}) S(X) 较高(路径弯曲)。
  • 经过 1 次 Reflow, S ( Z 1 ) S(\boldsymbol{Z}^1) S(Z1) 显著减小。
  • 经过 2 次 Reflow, S ( Z 2 ) S(\boldsymbol{Z}^2) S(Z2) 接近 0,路径几乎完全直线。

Reflow 的直观解释

类比:公路规划

想象 π 0 \pi_0 π0 π 1 \pi_1 π1 是两个城市群, ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 是城市之间的公路:

  • 初始:公路是直线,但会交叉(交通混乱)。
  • 第一次 Reflow:在交叉点设置立交桥(重新布线),避免交叉。
  • 第二次 Reflow:进一步优化路线,使公路更直,减少绕行。
类比:自蒸馏

Reflow 类似于“自蒸馏”(self-distillation):

  • 每次 Reflow 用前一步的输出作为新输入,逐步提纯路径。
  • 最终路径接近一步模型,类似于“知识蒸馏”中的学生模型逼近教师模型。

Reflow 的优势与局限

优势
  1. 路径优化:逐步将路径变直,降低运输成本。
  2. 高效推理:最终模型可以用一步生成,适用于实时应用。
  3. 通用性:适用于生成建模、域迁移、图像到图像翻译等任务。
局限
  1. 计算成本:每次 Reflow 需要重新训练一个神经网络,训练成本随 K K K 增加。
  2. 超参数 K K K 的选择需要实验调整(论文中通常用 K = 2 K=2 K=2)。
  3. 初始耦合依赖:初始 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 的质量会影响 Reflow 的效果。

总结

Reflow 是 Rectified Flow 的核心技术,通过递归校正逐步将路径变直。其过程包括:

  1. 从初始耦合 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 训练第一个 Rectified Flow。
  2. 用生成的 ( Z 0 1 , Z 1 1 ) (Z_0^1, Z_1^1) (Z01,Z11) 训练第二个 Rectified Flow。
  3. 重复 K K K 次,直到路径接近直线。

Reflow 的理论支持在于直线性指标 S ( Z ) S(\boldsymbol{Z}) S(Z) 的下降,实际效果在于显著提升生成质量和推理效率。对于深度学习研究者,Reflow 提供了一个优雅的迭代优化框架,值得进一步探索和扩展。

Nonlinear Extension

详细介绍《Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow》论文中的 2.3 A Nonlinear Extension3.5 Denoising Diffusion Models and Probability Flow ODEs 两个部分,包括核心思想、数学公式和意义。这两个部分分别探讨了 Rectified Flow 的扩展性以及它与扩散模型(Diffusion Models)的关系。


2.3 A Nonlinear Extension

背景

Rectified Flow 的核心思想是基于线性插值路径 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0,通过学习漂移场 v ( z , t ) v(z, t) v(z,t) 使 ODE 路径 Z t Z_t Zt 尽量接近直线方向 X 1 − X 0 X_1 - X_0 X1X0。然而,线性插值是一种简单的路径选择,可能不是最优的,尤其是在复杂的高维数据分布上。2.3 A Nonlinear Extension 部分提出了一种非线性扩展,允许用更一般的路径替换线性插值,从而增加模型的灵活性。

核心思想

非线性扩展的核心是将线性插值 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0 替换为一个非线性路径 X t = ϕ t ( X 0 , X 1 ) X_t = \phi_t(X_0, X_1) Xt=ϕt(X0,X1),其中 ϕ t \phi_t ϕt 是一个参数化的函数,满足:

  • ϕ 0 ( X 0 , X 1 ) = X 0 \phi_0(X_0, X_1) = X_0 ϕ0(X0,X1)=X0
  • ϕ 1 ( X 0 , X 1 ) = X 1 \phi_1(X_0, X_1) = X_1 ϕ1(X0,X1)=X1

目标仍然是让 Z t Z_t Zt 的路径接近 X ˙ t \dot{X}_t X˙t,但 X ˙ t \dot{X}_t X˙t 不再是恒定的 X 1 − X 0 X_1 - X_0 X1X0,而是随时间 t t t 变化的非线性速度。

数学公式
  1. 非线性路径
    X t = ϕ t ( X 0 , X 1 ) X_t = \phi_t(X_0, X_1) Xt=ϕt(X0,X1)
    其中 ϕ t : R d × R d → R d \phi_t: \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}^d ϕt:Rd×RdRd 是一个平滑函数,满足边界条件:
    ϕ 0 ( X 0 , X 1 ) = X 0 , ϕ 1 ( X 0 , X 1 ) = X 1 \phi_0(X_0, X_1) = X_0, \quad \phi_1(X_0, X_1) = X_1 ϕ0(X0,X1)=X0,ϕ1(X0,X1)=X1

  2. 目标速度
    路径的速度是 ϕ t \phi_t ϕt t t t 的导数:
    X ˙ t = ∂ ϕ t ( X 0 , X 1 ) ∂ t \dot{X}_t = \frac{\partial \phi_t(X_0, X_1)}{\partial t} X˙t=tϕt(X0,X1)

  3. 优化目标
    漂移场 v ( z , t ) v(z, t) v(z,t) 的优化目标变为:
    min ⁡ v ∫ 0 1 E [ ∥ ∂ ϕ t ( X 0 , X 1 ) ∂ t − v ( X t , t ) ∥ 2 2 ] d t \min_v \int_0^1 \mathbb{E}\left[ \left\| \frac{\partial \phi_t(X_0, X_1)}{\partial t} - v(X_t, t) \right\|_2^2 \right] \mathrm{d}t vmin01E[ tϕt(X0,X1)v(Xt,t) 22]dt

    • 相比于线性插值时的目标 E [ ∥ ( X 1 − X 0 ) − v ( X t , t ) ∥ 2 2 ] \mathbb{E}\left[ \| (X_1 - X_0) - v(X_t, t) \|_2^2 \right] E[(X1X0)v(Xt,t)22],这里用 ∂ ϕ t ∂ t \frac{\partial \phi_t}{\partial t} tϕt 替换了 X 1 − X 0 X_1 - X_0 X1X0
  4. 最优漂移场
    与线性情况类似,最优漂移场为条件期望:
    v X ( z , t ) = E [ ∂ ϕ t ( X 0 , X 1 ) ∂ t ∣ X t = z ] v^X(z, t) = \mathbb{E}\left[ \frac{\partial \phi_t(X_0, X_1)}{\partial t} \mid X_t = z \right] vX(z,t)=E[tϕt(X0,X1)Xt=z]

解释
  1. 非线性路径的意义

    • 线性插值 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0 假设路径是直线,但这可能不适合所有任务。例如,在图像生成中,从噪声到图像的变换可能需要更复杂的路径(如先形成轮廓,再填充细节)。
    • 非线性路径 ϕ t ( X 0 , X 1 ) \phi_t(X_0, X_1) ϕt(X0,X1) 允许更灵活的变换,比如曲线路径或基于领域知识设计的路径。
  2. 优化目标的变化

    • 目标仍然是让 v ( z , t ) v(z, t) v(z,t) 接近路径速度,但速度 ∂ ϕ t ∂ t \frac{\partial \phi_t}{\partial t} tϕt 不再是恒定的,而是随 t t t 变化。
    • 这增加了模型的表达能力,但也增加了计算复杂性。
  3. 实际实现

    • ϕ t \phi_t ϕt 可以用神经网络参数化,例如 ϕ t ( X 0 , X 1 ) = f θ ( X 0 , X 1 , t ) \phi_t(X_0, X_1) = f_\theta(X_0, X_1, t) ϕt(X0,X1)=fθ(X0,X1,t),其中 f θ f_\theta fθ 是一个 MLP。
    • 或者,可以用预定义的非线性函数,比如基于物理系统的轨迹。
直观解释
  • 类比:线性插值就像从城市 A 到城市 B 走直线公路。非线性扩展允许走一条曲线公路(比如绕过山脉),可能更符合实际需求。
  • 意义:非线性扩展为 Rectified Flow 提供了更大的灵活性,可以适应更复杂的分布变换任务。
实际意义
  • 在图像到图像翻译中,非线性路径可能更好地捕捉风格变化的中间状态。
  • 在生成建模中,非线性路径可能更适合模拟数据的生成过程(比如先形成低频结构,再添加高频细节)。
局限
  • 非线性路径增加了计算复杂性, ϕ t \phi_t ϕt 的设计需要额外调参。
  • 论文中没有深入实验非线性扩展,更多是理论上的探讨。

3.5 Denoising Diffusion Models and Probability Flow ODEs

背景

扩散模型(Denoising Diffusion Models, DDPM)和概率流 ODE(Probability Flow ODE, PF-ODE)是近年来生成建模领域的热门方法。Rectified Flow 是一种基于 ODE 的方法,论文在 3.5 部分探讨了它与扩散模型和 PF-ODE 的关系,指出 Rectified Flow 可以看作这些模型的非线性特例,同时具有更直、更快的路径。

核心思想
  1. 扩散模型:通过一个前向扩散过程(加噪)和逆向去噪过程(生成)实现从噪声分布到数据分布的变换。
  2. 概率流 ODE:将扩散模型的逆向过程表示为一个确定性 ODE,边际分布与扩散过程一致。
  3. Rectified Flow 的联系:Rectified Flow 是一种更一般的非线性流方法,路径更直,速度更均匀,推理效率更高。
数学公式
  1. 扩散模型的前向过程
    扩散模型定义了一个前向扩散过程(Markov 过程):
    q ( X t ∣ X 0 ) = N ( X t ; 1 − β t X 0 , β t I ) q(X_t \mid X_0) = \mathcal{N}(X_t; \sqrt{1 - \beta_t} X_0, \beta_t I) q(XtX0)=N(Xt;1βt X0,βtI)

    • X 0 ∼ π 1 X_0 \sim \pi_1 X0π1(数据分布)。
    • β t \beta_t βt 是时间 t t t 的噪声方差, β t ∈ ( 0 , 1 ) \beta_t \in (0, 1) βt(0,1)
    • X t X_t Xt 是加噪后的中间状态, X 1 ∼ N ( 0 , I ) X_1 \sim \mathcal{N}(0, I) X1N(0,I)(纯噪声)。
  2. 逆向过程
    扩散模型的逆向过程是一个去噪过程,近似为:
    p θ ( X t − 1 ∣ X t ) = N ( X t − 1 ; μ θ ( X t , t ) , Σ t ) p_\theta(X_{t-1} \mid X_t) = \mathcal{N}(X_{t-1}; \mu_\theta(X_t, t), \Sigma_t) pθ(Xt1Xt)=N(Xt1;μθ(Xt,t),Σt)

    • μ θ \mu_\theta μθ 是神经网络学习的均值,通常通过分数函数(score function)表示:
      μ θ ( X t , t ) = 1 1 − β t ( X t + β t ∇ log ⁡ p t ( X t ) ) \mu_\theta(X_t, t) = \frac{1}{\sqrt{1 - \beta_t}} \left( X_t + \beta_t \nabla \log p_t(X_t) \right) μθ(Xt,t)=1βt 1(Xt+βtlogpt(Xt))
    • ∇ log ⁡ p t ( X t ) \nabla \log p_t(X_t) logpt(Xt) 是分数函数,扩散模型通过去噪分数匹配(denoising score matching)学习。
  3. 概率流 ODE
    概率流 ODE(PF-ODE)将逆向过程表示为一个确定性 ODE:
    d X t = [ f ( t ) X t − 1 2 g ( t ) 2 ∇ log ⁡ p t ( X t ) ] d t \mathrm{d}X_t = \left[ f(t) X_t - \frac{1}{2} g(t)^2 \nabla \log p_t(X_t) \right] \mathrm{d}t dXt=[f(t)Xt21g(t)2logpt(Xt)]dt

    • f ( t ) f(t) f(t) g ( t ) g(t) g(t) 是扩散过程的漂移和扩散系数。
    • 对于 DDPM, f ( t ) = − 1 2 β t f(t) = -\frac{1}{2} \beta_t f(t)=21βt g ( t ) = β t g(t) = \sqrt{\beta_t} g(t)=βt
    • PF-ODE 的边际分布 Law ( X t ) \text{Law}(X_t) Law(Xt) 与扩散过程一致。
  4. DDIM(Denoising Diffusion Implicit Models)
    DDIM 是扩散模型的一种变体,定义了一个非 Markov 的逆向过程,ODE 形式为:
    d X t = − ϵ θ ( X t , t ) β t d t \mathrm{d}X_t = -\epsilon_\theta(X_t, t) \sqrt{\beta_t} \mathrm{d}t dXt=ϵθ(Xt,t)βt dt

    • ϵ θ ( X t , t ) \epsilon_\theta(X_t, t) ϵθ(Xt,t) 是神经网络预测的噪声,近似 − 1 − β t ∇ log ⁡ p t ( X t ) -\sqrt{1 - \beta_t} \nabla \log p_t(X_t) 1βt logpt(Xt)
  5. Rectified Flow 的形式
    Rectified Flow 的 ODE 为:
    d Z t = v ( Z t , t ) d t \mathrm{d}Z_t = v(Z_t, t) \mathrm{d}t dZt=v(Zt,t)dt
    其中 v ( z , t ) v(z, t) v(z,t) 通过优化目标学习:
    v ( z , t ) ≈ E [ X 1 − X 0 ∣ X t = z ] v(z, t) \approx \mathbb{E}\left[ X_1 - X_0 \mid X_t = z \right] v(z,t)E[X1X0Xt=z]

解释
  1. PF-ODE 和 DDIM 是 Rectified Flow 的特例

    • 论文指出,PF-ODE 和 DDIM 的 ODE 形式可以看作 Rectified Flow 的非线性扩展(2.3 节)。
    • 具体来说,PF-ODE 和 DDIM 的速度场 v ( z , t ) v(z, t) v(z,t) 是非线性的,且依赖于分数函数 ∇ log ⁡ p t ( z ) \nabla \log p_t(z) logpt(z)
    • Rectified Flow 的 v ( z , t ) v(z, t) v(z,t) 更简单,直接逼近直线速度 X 1 − X 0 X_1 - X_0 X1X0,路径更直。
  2. 路径直线性

    • 扩散模型和 PF-ODE 的路径通常是曲线的,速度不均匀(早期慢,后期快)。
    • Rectified Flow 通过 Reflow 过程使路径接近直线,速度更均匀,推理效率更高。
  3. 边际分布

    • PF-ODE 和 DDIM 的边际分布 Law ( X t ) \text{Law}(X_t) Law(Xt) 与扩散过程一致。
    • Rectified Flow 也保证 Law ( Z t ) = Law ( X t ) \text{Law}(Z_t) = \text{Law}(X_t) Law(Zt)=Law(Xt),但 X t X_t Xt 是线性插值路径,计算更简单。
直观解释
  • 类比:扩散模型就像一个“随机漫步”过程,从数据到噪声再回到数据,路径曲折且需要多步采样。PF-ODE 和 DDIM 将其简化为确定性路径,但仍然是曲线。Rectified Flow 则像一条“高速公路”,路径更直,速度更均匀,可以一步到达。
  • 意义:Rectified Flow 提供了一种更高效的替代方案,避免了扩散模型的多步采样,同时保留了生成质量。
实际意义
  • 推理效率:Rectified Flow 经过 Reflow 后可以用一步生成(single-step),而扩散模型通常需要数百步。
  • 生成质量:实验表明,2-Rectified Flow 在 CIFAR-10 上的 FID(4.85)优于一步扩散模型(FID 8.91)。
  • 理论联系:Rectified Flow 提供了一个统一的框架,将扩散模型和 PF-ODE 纳入其中,同时揭示了直线路径的优越性。
局限
  • 扩散模型在高维数据上可能更擅长捕捉复杂分布(因为分数函数更灵活)。
  • Rectified Flow 的直线路径可能限制了模型的表达能力(尽管非线性扩展可以缓解)。

两部分的联系

  1. 灵活性与效率的权衡

    • 2.3 A Nonlinear Extension 增加了 Rectified Flow 的灵活性,允许非线性路径,理论上可以包含 PF-ODE 和 DDIM 作为特例。
    • 3.5 Denoising Diffusion Models and Probability Flow ODEs 指出,Rectified Flow 的直线路径比 PF-ODE 和 DDIM 更高效,但可能牺牲一些表达能力。
  2. 路径设计

    • 非线性扩展允许更复杂的路径设计,可能更接近扩散模型的曲线路径。
    • Rectified Flow 的默认直线路径(经过 Reflow 优化)在效率上优于扩散模型。
  3. 应用场景

    • 非线性扩展适合需要复杂路径的任务(如某些图像到图像翻译)。
    • Rectified Flow 的直线路径适合高效生成和域迁移任务。

总结

2.3 A Nonlinear Extension
  • 核心思想:将线性插值 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0 扩展为非线性路径 X t = ϕ t ( X 0 , X 1 ) X_t = \phi_t(X_0, X_1) Xt=ϕt(X0,X1),优化目标变为逼近 ∂ ϕ t ∂ t \frac{\partial \phi_t}{\partial t} tϕt
  • 公式 min ⁡ v ∫ 0 1 E [ ∥ ∂ ϕ t ( X 0 , X 1 ) ∂ t − v ( X t , t ) ∥ 2 2 ] d t \min_v \int_0^1 \mathbb{E}\left[ \left\| \frac{\partial \phi_t(X_0, X_1)}{\partial t} - v(X_t, t) \right\|_2^2 \right] \mathrm{d}t minv01E[ tϕt(X0,X1)v(Xt,t) 22]dt
  • 意义:增加模型灵活性,适应更复杂的分布变换任务。
3.5 Denoising Diffusion Models and Probability Flow ODEs
  • 核心思想:Rectified Flow 是扩散模型和 PF-ODE 的一种非线性特例,但路径更直,速度更均匀,推理效率更高。
  • 公式:PF-ODE 的 d X t = [ f ( t ) X t − 1 2 g ( t ) 2 ∇ log ⁡ p t ( X t ) ] d t \mathrm{d}X_t = \left[ f(t) X_t - \frac{1}{2} g(t)^2 \nabla \log p_t(X_t) \right] \mathrm{d}t dXt=[f(t)Xt21g(t)2logpt(Xt)]dt vs. Rectified Flow 的 d Z t = v ( Z t , t ) d t \mathrm{d}Z_t = v(Z_t, t) \mathrm{d}t dZt=v(Zt,t)dt
  • 意义:揭示了 Rectified Flow 与扩散模型的理论联系,同时强调了直线路径的高效性。

这两个部分共同展示了 Rectified Flow 的理论深度和实践优势:非线性扩展增加了灵活性,而与扩散模型的对比突出了效率优势。对于深度学习研究者,这提供了宝贵的思路:如何在灵活性和效率之间找到平衡。

代码实现

Rectified Flow 的训练代码以及用于图像生成和图像到图像翻译的代码。由于 Rectified Flow 是一种基于 ODE 的方法,涉及神经网络训练和 ODE 求解,会使用 Python 和 PyTorch 实现代码,并确保代码可运行。代码将分为两部分:训练代码和应用代码(图像生成 + 图像到图像翻译)。还会详细介绍每个部分的作用和实现细节。


环境与假设

  • 环境:代码将在 Pyodide 环境中运行,因此避免本地文件 I/O 和网络请求。
  • :使用 PyTorch(Pyodide 支持)和 NumPy。
  • 数据:由于无法访问真实数据集,将用合成数据(2D 点分布)进行训练和测试,并在注释中说明如何替换为真实图像数据。
  • ODE 求解:使用简单的 Euler 方法求解 ODE(生产环境中可以用更高级的求解器如 torchdiffeq)。

1. Rectified Flow 训练代码

目标

训练一个 Rectified Flow 模型,学习从分布 π 0 \pi_0 π0 π 1 \pi_1 π1 的漂移场 v θ ( z , t ) v_\theta(z, t) vθ(z,t)。我们将:

  1. 生成合成数据: π 0 \pi_0 π0 π 1 \pi_1 π1 是两个 2D 点分布。
  2. 定义一个神经网络 v θ ( z , t ) v_\theta(z, t) vθ(z,t)
  3. 优化目标: min ⁡ v θ ∫ 0 1 E [ ∥ ( X 1 − X 0 ) − v θ ( X t , t ) ∥ 2 2 ] d t \min_{v_\theta} \int_0^1 \mathbb{E}\left[ \| (X_1 - X_0) - v_\theta(X_t, t) \|_2^2 \right] \mathrm{d}t minvθ01E[(X1X0)vθ(Xt,t)22]dt
  4. 支持 Reflow 过程。
代码实现
import torch
import torch.nn as nn
import numpy as np

# 1. 定义漂移场神经网络 v_θ(z, t)
class VelocityNet(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(VelocityNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),  # 输入是 (z, t)
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)  # 输出是 v(z, t)
        )

    def forward(self, z, t):
        # z: (batch_size, input_dim), t: (batch_size, 1)
        t = t.view(-1, 1)
        zt = torch.cat([z, t], dim=1)
        return self.net(zt)

# 2. 生成合成数据:π_0 和 π_1
def generate_synthetic_data(n_samples=1000, dim=2):
    # π_0: 中心在 (-1, -1) 的高斯分布
    pi_0 = torch.randn(n_samples, dim) * 0.5 + torch.tensor([-1.0, -1.0])
    # π_1: 中心在 (1, 1) 的高斯分布
    pi_1 = torch.randn(n_samples, dim) * 0.5 + torch.tensor([1.0, 1.0])
    return pi_0, pi_1

# 3. 训练 Rectified Flow
def train_rectified_flow(pi_0, pi_1, input_dim=2, hidden_dim=128, num_epochs=1000, lr=1e-3, batch_size=128):
    model = VelocityNet(input_dim, hidden_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    n_samples = pi_0.shape[0]
    for epoch in range(num_epochs):
        # 随机采样 batch
        indices = torch.randperm(n_samples)[:batch_size]
        x0 = pi_0[indices]
        x1 = pi_1[indices]
        
        # 随机采样时间 t
        t = torch.rand(batch_size, 1)
        # 线性插值:X_t = t * X_1 + (1 - t) * X_0
        x_t = t * x1 + (1 - t) * x0
        
        # 目标速度:X_1 - X_0
        target_velocity = x1 - x0
        
        # 预测速度:v_θ(X_t, t)
        predicted_velocity = model(x_t, t)
        
        # 损失:均方误差
        loss = ((predicted_velocity - target_velocity) ** 2).mean()
        
        # 优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    
    return model

# 4. Reflow 过程
def reflow(pi_0, pi_1, K=2, num_epochs_per_flow=1000):
    coupling = (pi_0, pi_1)
    model = None
    for k in range(K):
        print(f"Reflow iteration {k+1}/{K}")
        # 训练 Rectified Flow
        model = train_rectified_flow(coupling[0], coupling[1], num_epochs=num_epochs_per_flow)
        
        # 用训练好的模型采样新耦合
        n_samples = pi_0.shape[0]
        z0 = coupling[0]  # 保持 π_0 不变
        z1 = torch.zeros_like(z0)
        for i in range(n_samples):
            z1[i] = sample_rectified_flow(model, z0[i])
        
        # 更新耦合
        coupling = (z0, z1)
    
    return model, coupling

# 5. 用 Rectified Flow 采样
def sample_rectified_flow(model, z0, num_steps=100):
    z_t = z0.clone()
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = torch.tensor(step * dt)
        v = model(z_t, t)
        z_t = z_t + v * dt  # Euler 积分
    return z_t

# 6. 主函数
def main():
    # 生成合成数据
    pi_0, pi_1 = generate_synthetic_data()
    
    # 训练并执行 Reflow
    print("Training Rectified Flow with Reflow...")
    model, final_coupling = reflow(pi_0, pi_1, K=2)
    
    # 打印结果
    print(">>> Final coupling (first 5 samples):")
    print("Z_0:", final_coupling[0][:5])
    print("Z_1:", final_coupling[1][:5])

if __name__ == "__main__":
    main()
代码详细介绍
  1. VelocityNet

    • 定义了一个简单的 MLP,输入是 ( z , t ) (z, t) (z,t),输出是 v θ ( z , t ) v_\theta(z, t) vθ(z,t)
    • 网络结构:输入层(dim+1,包括 z z z t t t)、两层隐藏层(ReLU 激活)、输出层(dim 维速度)。
  2. generate_synthetic_data

    • 生成了两个 2D 高斯分布: π 0 \pi_0 π0 中心在 ( − 1 , − 1 ) (-1, -1) (1,1) π 1 \pi_1 π1 中心在 ( 1 , 1 ) (1, 1) (1,1)
    • 在真实场景中,可以替换为图像数据(比如 π 0 \pi_0 π0 是噪声, π 1 \pi_1 π1 是 CIFAR-10 图像)。
  3. train_rectified_flow

    • 实现 Rectified Flow 的训练。
    • 每轮迭代:
      • 采样 batch 数据 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 和时间 t t t
      • 计算线性插值 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0
      • 计算目标速度 X 1 − X 0 X_1 - X_0 X1X0 和预测速度 v θ ( X t , t ) v_\theta(X_t, t) vθ(Xt,t)
      • 优化均方误差损失。
  4. reflow

    • 实现了 Reflow 过程。
    • 每次 Reflow:
      • 用当前耦合 ( Z 0 k , Z 1 k ) (Z_0^k, Z_1^k) (Z0k,Z1k) 训练一个新的 Rectified Flow。
      • 用训练好的模型采样新耦合 ( Z 0 k + 1 , Z 1 k + 1 ) (Z_0^{k+1}, Z_1^{k+1}) (Z0k+1,Z1k+1)
  5. sample_rectified_flow

    • 用 Euler 方法求解 ODE,生成从 Z 0 Z_0 Z0 Z 1 Z_1 Z1 的路径。
    • 生产环境中可以用更高级的 ODE 求解器(如 torchdiffeq)。
  6. main

    • 生成数据,执行训练和 Reflow,打印最终耦合。
如何替换为真实图像数据
  • π 0 \pi_0 π0:用高斯噪声代替(torch.randn(batch_size, channels, height, width))。
  • π 1 \pi_1 π1:用真实图像数据集(如 CIFAR-10)。需要加载数据集并预处理(归一化到 [ − 1 , 1 ] [-1, 1] [1,1])。
  • 输入维度:将 input_dim 改为图像的展平维度(例如 CIFAR-10 是 3 × 32 × 32 = 3072 3 \times 32 \times 32 = 3072 3×32×32=3072)。
  • 网络结构:可以用卷积神经网络(CNN)替换 MLP,比如 U-Net 结构。

2. 图像生成和图像到图像翻译代码

目标

使用训练好的 Rectified Flow 模型执行:

  1. 图像生成:从噪声分布 π 0 \pi_0 π0 生成图像( π 1 \pi_1 π1 是目标图像分布)。
  2. 图像到图像翻译:从源域图像 π 0 \pi_0 π0 翻译到目标域图像 π 1 \pi_1 π1
假设
  • 由于无法加载真实图像数据,继续使用 2D 点分布的合成数据。
  • 会说明如何修改代码以处理真实图像数据。
代码实现
import torch
import torch.nn as nn
import numpy as np

# 1. 定义漂移场神经网络(与训练代码相同)
class VelocityNet(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(VelocityNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, z, t):
        t = t.view(-1, 1)
        zt = torch.cat([z, t], dim=1)
        return self.net(zt)

# 2. ODE 采样
def sample_rectified_flow(model, z0, num_steps=100):
    z_t = z0.clone()
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = torch.tensor(step * dt)
        v = model(z_t, t)
        z_t = z_t + v * dt
    return z_t

# 3. 图像生成
def image_generation(model, n_samples=5, input_dim=2):
    # 从 π_0 采样噪声(这里用高斯噪声)
    z0 = torch.randn(n_samples, input_dim) * 0.5 + torch.tensor([-1.0, -1.0])
    print(">>> Initial samples (π_0, noise):")
    print(z0)
    
    # 用 Rectified Flow 生成样本
    z1 = torch.zeros_like(z0)
    for i in range(n_samples):
        z1[i] = sample_rectified_flow(model, z0[i])
    
    print(">>> Generated samples (π_1, target):")
    print(z1)
    return z0, z1

# 4. 图像到图像翻译
def image_to_image_translation(model, source_samples, num_steps=100):
    # source_samples: 从源域 π_0 采样的样本
    translated_samples = torch.zeros_like(source_samples)
    for i in range(source_samples.shape[0]):
        translated_samples[i] = sample_rectified_flow(model, source_samples[i], num_steps)
    
    print(">>> Source samples (π_0):")
    print(source_samples)
    print(">>> Translated samples (π_1):")
    print(translated_samples)
    return translated_samples

# 5. 主函数
def main():
    # 假设已经有一个训练好的模型
    input_dim = 2
    hidden_dim = 128
    model = VelocityNet(input_dim, hidden_dim)
    # 随机初始化权重(模拟训练好的模型)
    for param in model.parameters():
        nn.init.normal_(param, 0, 0.1)
    
    # 图像生成
    print("Image Generation:")
    z0, z1 = image_generation(model, n_samples=5)
    
    # 图像到图像翻译
    print("\nImage-to-Image Translation:")
    source_samples = torch.randn(5, input_dim) * 0.5 + torch.tensor([-1.0, -1.0])  # 模拟源域样本
    translated_samples = image_to_image_translation(model, source_samples)

if __name__ == "__main__":
    main()
代码详细介绍
  1. VelocityNet

    • 与训练代码中的网络相同,用于加载训练好的模型。
  2. sample_rectified_flow

    • 实现 ODE 采样,与训练代码相同。
  3. image_generation

    • 从噪声分布 π 0 \pi_0 π0 采样初始样本 Z 0 Z_0 Z0
    • 用训练好的模型生成 Z 1 Z_1 Z1,目标是接近 π 1 \pi_1 π1
    • 打印初始和生成样本。
  4. image_to_image_translation

    • 输入源域样本( π 0 \pi_0 π0),用模型翻译到目标域( π 1 \pi_1 π1)。
    • 打印源域和翻译后的样本。
  5. main

    • 模拟一个训练好的模型(随机初始化权重)。
    • 分别调用图像生成和图像到图像翻译函数。
如何替换为真实图像数据
  • 图像生成
    • Z 0 Z_0 Z0 替换为高斯噪声(torch.randn(batch_size, channels, height, width))。
    • 模型输出 Z 1 Z_1 Z1 是生成的图像,可以用 torchvision.utils.save_image 保存(但 Pyodide 不支持,需手动修改)。
  • 图像到图像翻译
    • source_samples 替换为源域图像(比如人脸图像)。
    • 模型输出是目标域图像(比如猫脸图像)。
  • 网络结构
    • 用 CNN(如 U-Net)替换 MLP,适应图像数据。
    • 输入维度改为 (channels, height, width),可以用 torch.nn.Conv2d 处理。
实际场景中的修改
  • 数据预处理:图像需要归一化到 [ − 1 , 1 ] [-1, 1] [1,1]
  • 损失函数:对于图像到图像翻译,可以引入额外的损失(如循环一致性损失)。
  • ODE 求解:用更高级的求解器(如 RK4 或 torchdiffeq)提高精度。

运行说明

  1. 训练代码

    • 运行 rectified_flow_train.py,会生成合成数据并训练一个 Rectified Flow 模型。
    • 训练完成后,执行 Reflow 过程,生成最终耦合。
  2. 应用代码

    • 运行 rectified_flow_application.py,会用训练好的模型(这里用随机初始化的模型模拟)执行图像生成和图像到图像翻译。
    • 输出是 2D 点坐标,模拟了从 π 0 \pi_0 π0 π 1 \pi_1 π1 的变换。
  3. 输出示例(假设值):

    • 图像生成:
      >>> Initial samples (π_0, noise):
      tensor([[-1.2, -0.8], [-0.9, -1.1], ...])
      >>> Generated samples (π_1, target):
      tensor([[0.9, 1.2], [1.1, 0.8], ...])
      
    • 图像到图像翻译:
      >>> Source samples (π_0):
      tensor([[-1.3, -0.7], [-1.0, -1.2], ...])
      >>> Translated samples (π_1):
      tensor([[1.0, 0.9], [0.8, 1.1], ...])
      

总结

训练代码
  • 实现了 Rectified Flow 的训练和 Reflow 过程。
  • 使用合成 2D 数据,优化目标是让 v θ ( X t , t ) v_\theta(X_t, t) vθ(Xt,t) 接近 X 1 − X 0 X_1 - X_0 X1X0
  • 支持递归校正,路径逐步变直。
应用代码
  • 实现了图像生成和图像到图像翻译。
  • 从噪声生成目标分布样本,或从源域翻译到目标域。
  • 提供了如何修改为真实图像数据的说明。
实际应用中的注意事项
  • 模型容量:对于图像数据,建议使用更大的网络(如 U-Net)。
  • 采样步数num_steps 影响生成质量,建议调整。
  • Reflow 次数 K K K 越大,路径越直,但训练成本增加。

这些代码提供了一个可运行的起点,可以根据具体任务进一步优化。

图像翻译的问题


问题背景

在 Rectified Flow 中,图像到图像翻译的目标是从源域分布 π 0 \pi_0 π0(如人脸图像)变换到目标域分布 π 1 \pi_1 π1(如猫脸图像)。代码中的 image_to_image_translation 函数看起来很简单:

def image_to_image_translation(model, source_samples, num_steps=100):
    translated_samples = torch.zeros_like(source_samples)
    for i in range(source_samples.shape[0]):
        translated_samples[i] = sample_rectified_flow(model, source_samples[i], num_steps)
    print(">>> Source samples (π_0):")
    print(source_samples)
    print(">>> Translated samples (π_1):")
    print(translated_samples)
    return translated_samples

但你提出了一个关键问题:模型如何知道目标域是猫脸? 答案在于训练阶段的设计和漂移场 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 的学习过程。image_to_image_translation 只是推理阶段的代码,真正的“知识”是在训练阶段通过数据和优化目标嵌入到模型中的。


1. 训练阶段:如何让模型知道目标是猫脸?

Rectified Flow 的训练目标是学习一个漂移场 v θ ( z , t ) v_\theta(z, t) vθ(z,t),使得 ODE d Z t = v θ ( Z t , t ) d t \mathrm{d}Z_t = v_\theta(Z_t, t) \mathrm{d}t dZt=vθ(Zt,t)dt 能够将样本从 π 0 \pi_0 π0 变换到 π 1 \pi_1 π1。在图像到图像翻译任务中, π 0 \pi_0 π0 是人脸分布, π 1 \pi_1 π1 是猫脸分布。以下是训练阶段的关键步骤:

1.1 数据准备
  • 源域 π 0 \pi_0 π0:收集人脸图像数据集(比如 CelebA)。
  • 目标域 π 1 \pi_1 π1:收集猫脸图像数据集(比如 AFHQ 中的猫脸部分)。
  • 耦合 ( X 0 , X 1 ) (X_0, X_1) (X0,X1):Rectified Flow 通常假设 ( X 0 , X 1 ) ∼ π 0 × π 1 (X_0, X_1) \sim \pi_0 \times \pi_1 (X0,X1)π0×π1,即 X 0 X_0 X0 X 1 X_1 X1 是独立采样的(无配对数据)。这意味着人脸和猫脸不需要一一对应。
1.2 训练目标

Rectified Flow 的优化目标是:
min ⁡ v θ ∫ 0 1 E [ ∥ ( X 1 − X 0 ) − v θ ( X t , t ) ∥ 2 2 ] d t , X t = t X 1 + ( 1 − t ) X 0 \min_{v_\theta} \int_0^1 \mathbb{E}\left[ \| (X_1 - X_0) - v_\theta(X_t, t) \|_2^2 \right] \mathrm{d}t, \quad X_t = t X_1 + (1 - t) X_0 vθmin01E[(X1X0)vθ(Xt,t)22]dt,Xt=tX1+(1t)X0

  • X 0 ∼ π 0 X_0 \sim \pi_0 X0π0(人脸), X 1 ∼ π 1 X_1 \sim \pi_1 X1π1(猫脸)。
  • X t X_t Xt 是线性插值路径,表示从人脸到猫脸的中间状态。
  • X 1 − X 0 X_1 - X_0 X1X0 是目标速度,表示从人脸到猫脸的“方向”。
  • v θ ( X t , t ) v_\theta(X_t, t) vθ(Xt,t) 是模型预测的速度。

通过优化这个目标,模型学习到的 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 会在每个时间 t t t 和位置 z z z 处,预测一个速度向量,指向从人脸分布到猫脸分布的方向。

1.3 训练过程
  • 采样:从 π 0 \pi_0 π0 π 1 \pi_1 π1 采样 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 对,比如 X 0 X_0 X0 是一张人脸图像, X 1 X_1 X1 是一张猫脸图像。
  • 插值:计算 X t = t X 1 + ( 1 − t ) X 0 X_t = t X_1 + (1 - t) X_0 Xt=tX1+(1t)X0,这是从人脸到猫脸的线性插值。
  • 优化:让 v θ ( X t , t ) v_\theta(X_t, t) vθ(Xt,t) 接近 X 1 − X 0 X_1 - X_0 X1X0,即从人脸到猫脸的变换方向。
  • Reflow:通过递归校正(Reflow),使路径更直,最终接近一步变换。
1.4 模型如何“知道”目标是猫脸?
  • 通过数据:训练数据中, X 1 X_1 X1 始终来自猫脸分布 π 1 \pi_1 π1。模型通过优化目标,学习到 X 1 − X 0 X_1 - X_0 X1X0 的统计规律,即从人脸到猫脸的变换方向。
  • 统计学习 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 是一个神经网络,它在训练中看到大量 ( X 0 , X 1 ) (X_0, X_1) (X0,X1) 对,逐渐学会了从人脸特征( π 0 \pi_0 π0)到猫脸特征( π 1 \pi_1 π1)的映射。
  • 边际分布:Rectified Flow 保证 Z 1 ∼ π 1 Z_1 \sim \pi_1 Z1π1(猫脸分布),因为训练目标确保了 Z t Z_t Zt 的边际分布从 π 0 \pi_0 π0 演变为 π 1 \pi_1 π1

直观解释

  • 想象 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 是一个“导航系统”。在训练时,它看到很多人脸-猫脸对 ( X 0 , X 1 ) (X_0, X_1) (X0,X1),学会了从人脸到猫脸的“方向”(比如将眼睛变大、添加胡须等特征)。
  • 在推理时,给定一张人脸 Z 0 Z_0 Z0,模型会根据 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 的指引,逐步将 Z 0 Z_0 Z0 变换为 Z 1 Z_1 Z1,而 Z 1 Z_1 Z1 会自然落在猫脸分布 π 1 \pi_1 π1 上。

2. 推理阶段:image_to_image_translation 的作用

现在回到 image_to_image_translation 函数:

def image_to_image_translation(model, source_samples, num_steps=100):
    translated_samples = torch.zeros_like(source_samples)
    for i in range(source_samples.shape[0]):
        translated_samples[i] = sample_rectified_flow(model, source_samples[i], num_steps)
    print(">>> Source samples (π_0):")
    print(source_samples)
    print(">>> Translated samples (π_1):")
    print(translated_samples)
    return translated_samples
2.1 函数的作用
  • 输入
    • model:训练好的漂移场 v θ ( z , t ) v_\theta(z, t) vθ(z,t)
    • source_samples:源域样本(人脸图像),即 Z 0 ∼ π 0 Z_0 \sim \pi_0 Z0π0
    • num_steps:ODE 求解的步数。
  • 输出
    • translated_samples:目标域样本(猫脸图像),即 Z 1 ∼ π 1 Z_1 \sim \pi_1 Z1π1
2.2 内部调用 sample_rectified_flow

sample_rectified_flow 函数用 Euler 方法求解 ODE:

def sample_rectified_flow(model, z0, num_steps=100):
    z_t = z0.clone()
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = torch.tensor(step * dt)
        v = model(z_t, t)
        z_t = z_t + v * dt  # Euler 积分
    return z_t
  • 输入:初始样本 Z 0 Z_0 Z0(人脸图像)。
  • 过程
    • t = 0 t=0 t=0 t = 1 t=1 t=1,逐步更新 Z t Z_t Zt
    • 每一步用 v θ ( Z t , t ) v_\theta(Z_t, t) vθ(Zt,t) 计算速度,沿速度方向移动。
  • 输出 Z 1 Z_1 Z1(猫脸图像)。
2.3 为什么 Z 1 Z_1 Z1 是猫脸?
  • 训练的知识 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 在训练时已经学会了从人脸到猫脸的变换方向。
  • ODE 的演化:从 Z 0 Z_0 Z0(人脸)开始, v θ ( z , t ) v_\theta(z, t) vθ(z,t) 提供了一个连续的变换路径,最终将 Z 0 Z_0 Z0 推向 π 1 \pi_1 π1(猫脸分布)。
  • 边际分布保证:Rectified Flow 的理论性质( Law ( Z t ) = Law ( X t ) \text{Law}(Z_t) = \text{Law}(X_t) Law(Zt)=Law(Xt))确保 Z 1 Z_1 Z1 的分布与 π 1 \pi_1 π1 一致。

直观解释

  • 给定一张人脸图像 Z 0 Z_0 Z0,模型会根据 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 的指引,逐步调整图像的特征(比如将人脸的眼睛变大、鼻子变小、添加胡须等),最终生成一张猫脸图像 Z 1 Z_1 Z1
  • 模型并不“知道”猫脸是什么,而是通过训练数据学习到了猫脸的统计特征( π 1 \pi_1 π1),并将 Z 0 Z_0 Z0 推向这个分布。

3. 实际实现中的细节

3.1 网络结构
  • 训练代码中的 VelocityNet
    • 目前是一个简单的 MLP,适合 2D 点数据。
    • 对于图像数据,建议使用卷积神经网络(CNN),如 U-Net:
      class VelocityNet(nn.Module):
          def __init__(self, channels=3):
              super(VelocityNet, self).__init__()
              self.unet = UNet(channels, channels)  # 输入和输出都是图像
              self.time_embedding = nn.Linear(1, 64)  # 时间嵌入
      
          def forward(self, z, t):
              t_emb = torch.sin(self.time_embedding(t.view(-1, 1)))
              return self.unet(z, t_emb)  # U-Net 处理图像
      
    • U-Net 适合处理图像数据,能够捕捉空间特征。
3.2 数据预处理
  • 人脸和猫脸图像
    • 通常需要调整到相同大小(比如 64 × 64 64 \times 64 64×64 256 × 256 256 \times 256 256×256)。
    • 归一化到 [ − 1 , 1 ] [-1, 1] [1,1]
      transform = transforms.Compose([
          transforms.Resize((64, 64)),
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到 [-1, 1]
      ])
      
3.3 训练数据
  • 无配对数据:Rectified Flow 不需要人脸和猫脸一一对应,只需要两个分布的样本。
  • 训练代码修改
    • 替换 generate_synthetic_data 为真实数据集加载:
      from torchvision import datasets, transforms
      
      # 加载人脸和猫脸数据集
      transform = transforms.Compose([...])
      pi_0_dataset = datasets.CelebA(root='data/', transform=transform)  # 人脸
      pi_1_dataset = datasets.ImageFolder(root='data/afhq/cat/', transform=transform)  # 猫脸
      pi_0_loader = torch.utils.data.DataLoader(pi_0_dataset, batch_size=128, shuffle=True)
      pi_1_loader = torch.utils.data.DataLoader(pi_1_dataset, batch_size=128, shuffle=True)
      
3.4 推理时的输入
  • source_samples
    • 在推理时,source_samples 是人脸图像(形状为 (batch_size, channels, height, width))。
    • 确保与训练时的预处理一致(归一化到 [ − 1 , 1 ] [-1, 1] [1,1])。
3.5 可视化
  • 保存翻译结果
    • 推理后,translated_samples 是猫脸图像,可以用 torchvision.utils.save_image 保存:
      from torchvision.utils import save_image
      
      # 反归一化
      translated_samples = (translated_samples + 1) / 2  # 从 [-1, 1] 转换到 [0, 1]
      save_image(translated_samples, "translated_cats.png", nrow=8)
      
    • 注:Pyodide 环境中无法保存文件,需在本地运行。

4. 改进:如何增强翻译效果?

4.1 引入特征空间损失
  • 直接用像素空间的 X 1 − X 0 X_1 - X_0 X1X0 可能不够鲁棒,因为人脸和猫脸的像素差异可能很大。
  • 论文中提到了一种改进方法(第5.2节):在特征空间中优化:
    min ⁡ v θ ∫ 0 1 E [ ∥ h ( X 1 ) − h ( X 0 ) − v θ ( h ( X t ) , t ) ∥ 2 2 ] d t \min_{v_\theta} \int_0^1 \mathbb{E}\left[ \| h(X_1) - h(X_0) - v_\theta(h(X_t), t) \|_2^2 \right] \mathrm{d}t vθmin01E[h(X1)h(X0)vθ(h(Xt),t)22]dt
    • h ( ⋅ ) h(\cdot) h() 是一个预训练的特征提取器(比如 VGG 的中间层)。
    • h ( X 0 ) h(X_0) h(X0) h ( X 1 ) h(X_1) h(X1) 是人脸和猫脸的特征表示。
    • 这样可以捕捉更高层次的语义差异(比如眼睛形状、毛发纹理)。
4.2 引入循环一致性
  • Rectified Flow 默认是单向变换( π 0 → π 1 \pi_0 \to \pi_1 π0π1)。可以训练一个反向模型( π 1 → π 0 \pi_1 \to \pi_0 π1π0),并引入循环一致性损失:
    • 从人脸 Z 0 Z_0 Z0 翻译到猫脸 Z 1 Z_1 Z1
    • Z 1 Z_1 Z1 翻译回人脸 Z 0 ′ Z_0' Z0
    • 最小化 ∥ Z 0 − Z 0 ′ ∥ 2 2 \| Z_0 - Z_0' \|_2^2 Z0Z022
4.3 条件输入
  • 如果有条件信息(比如目标域的标签“猫”),可以将其嵌入到 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 中:
    class VelocityNet(nn.Module):
        def forward(self, z, t, condition):
            t_emb = torch.sin(self.time_embedding(t.view(-1, 1)))
            cond_emb = self.condition_embedding(condition)  # 条件嵌入
            return self.unet(z, t_emb, cond_emb)
    
    • 但 Rectified Flow 默认不依赖条件输入,完全通过 π 1 \pi_1 π1 的数据分布学习目标。

5. 总结:模型如何知道目标是猫脸?

  1. 训练阶段

    • 模型通过 ( X 0 , X 1 ) ∼ π 0 × π 1 (X_0, X_1) \sim \pi_0 \times \pi_1 (X0,X1)π0×π1 的训练数据学习从人脸到猫脸的变换方向。
    • v θ ( z , t ) v_\theta(z, t) vθ(z,t) 捕捉了从 π 0 \pi_0 π0 π 1 \pi_1 π1 的统计规律。
  2. 推理阶段

    • image_to_image_translation 用训练好的 v θ ( z , t ) v_\theta(z, t) vθ(z,t),从人脸 Z 0 Z_0 Z0 开始,沿 v θ ( z , t ) v_\theta(z, t) vθ(z,t) 的方向演化,最终到达猫脸分布 π 1 \pi_1 π1
    • 模型并不显式“知道”目标是猫脸,而是通过训练数据和边际分布的约束,自然将 Z 0 Z_0 Z0 推向 π 1 \pi_1 π1
  3. 改进方向

    • 用特征空间损失增强语义一致性。
    • 引入循环一致性或条件输入提高翻译质量。
直观类比
  • 训练时,模型就像一个“画家”,看到很多人脸-猫脸对,学会了如何把人脸画成猫脸(比如把鼻子变小、添加胡须)。
  • 推理时,给定一张人脸,模型会一步步“修改”这张画,最终画出一张猫脸。

后记

2025年4月5日16点34分于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值