神经常微分方程——理解篇

1 常微分方程

常微分方程只包含单个自变量 t t t,未知函数 y ( t ) y(t) y(t)和未知函数的导数 y ′ ( t ) y'(t) y(t)的等式,例如: y ′ ( t ) = 2 t y'(t)=2t y(t)=2t。可以写成如下通用的形式:

y ( 0 ) = y 0 ; d y d t ( t ) = f θ ( t , y ( t ) ) (1) y(0)=y_{0}; \frac{dy}{dt}(t)=f_{\theta}(t, y(t)) \tag{1} y(0)=y0;dtdy(t)=fθ(t,y(t))(1)

其中, f θ ( t , y ( t ) ) f_{\theta}(t, y(t)) fθ(t,y(t))表示由 t t t y ( t ) y(t) y(t)组成的某个函数,函数的常数为参数 θ \theta θ

求解常微分方程的方式有两种:

1)求出解析解,得到 y ( t ) y(t) y(t)的具体形式,例如: y ′ ( t ) = 2 t y'(t)=2t y(t)=2t的通解为 y ( t ) = t 2 + C y(t)=t^2+C y(t)=t2+C, C C C y ( 0 ) = y 0 y(0)=y_0 y(0)=y0代入确定;

2)求出数值解,例如:欧拉法,迭代求解出 y ( t ) y(t) y(t)在某 t t t位置的函数值,其核心思想是用切线逐步逼近求解函数:
y n + 1 = y n + f θ ( t n , y n ) ( t n + 1 − t n ) (2) y_{n+1}=y_{n}+f_{\theta}(t_n, y_{n})(t_{n+1}-t_n)\tag{2} yn+1=yn+fθ(tn,yn)(tn+1tn)(2)

对于实际问题,大部分情况下都无法得到解析解。因此,求出数值解就成了唯一可行的方式。如下面示意图所示,在知道常微分方程的形式(公式(1)),我们便可以 ( t 0 , y 0 ) → ( t 1 , y 1 ) → . . . (t_0, y_0)\rightarrow(t_1, y_1)\rightarrow ... (t0,y0)(t1,y1)...迭代的出解出目标位置处的解析解。

在这里插入图片描述

2 从残差网络到常微分方程

一层残差网络可形式化为下图:
在这里插入图片描述
我们可以将上面的残差网络表式成以下方程:

h t + 1 = h t + f θ ( h t ) (3) h_{t+1}=h_t+f_{\theta}(h_t) \tag{3} ht+1=ht+fθ(ht)(3)

重新变化式(2)与式(3)得:

常微分离散化形式 y n + 1 − y n t n + 1 − t n = f θ ( t n , y n ) \frac{y_{n+1}-y_{n}}{t_{n+1}-t_n}=f_{\theta}(t_n, y_{n}) tn+1tnyn+1yn=fθ(tn,yn)
残差网络形式 h t + 1 − h t 1 = f θ ( h t ) \frac{h_{t+1}-h_t}{1}=f_{\theta}(h_t) 1ht+1ht=fθ(ht)

可知残差网络是状态步长为1,且不显式地包含自变量 t t t的常微分方程。要使得残差网络能表示更一般的常微分方程,可以设计如下的网络结构(主要在输入增加变量t):

在这里插入图片描述
从上可以分析得出,残差网络(本质上和欧拉法一样)可用来计算常微分方程的数值解。通过给 f θ ( t , h t ) f_{\theta}(t, h_t) fθ(t,ht)乘上可变的步长 d t dt dt,便得到更加一般的形式:
h t + 1 = h t + f θ ( t , h t ) d t h_{t+1}=h_t+f_{\theta}(t, h_t)dt ht+1=ht+fθ(t,ht)dt

但是,用欧拉法与残差网络来计算常微分方程的数值解太过粗糙。如果用一个抽象的概念来代替欧拉法与残差网络,例如ODESolver网络,其中ODESolver是一个函数,它提供了ODE的解决方法,其精度比欧拉法高得多。这就是神经常微分方程(把ODESolver当成一个黑盒):

在这里插入图片描述

3 怎么训练神经常微分方程

在算法 1 1中,陈天琦等研究者展示了如何借助另一个 OED Solver 一次性求出反向传播的各种梯度和更新量。要理解算法 1,首先我们要熟悉 ODESolver 的表达方式。例如在 ODEnet 的前向传播中,求解过程可以表示为 ODEsolver(z(t_0), f, t_0, t_1, θ),我们可以理解为从 t_0 时刻开始令 z(t_0) 以变化率 f 进行演化,这种演化即 f 在 t 上的积分,ODESolver 的目标是通过积分求得 z(t_1)。
在这里插入图片描述
同样我们能以这种方式理解算法 1,我们的目的是利用 ODESolver 从 z(t_1) 求出 z(t_0)、从 a(t_1) 按照方程 4 积出 a(t_0)、从 0 按照方程 5 积出 dL/dθ。最后我们只需要使用 dL/dθ 更新神经网络 f(z(t), t, θ) 就完成了整个反向传播过程。

在这里插入图片描述

4. 伴随法BP的推导

动态微分系统的数据集中的元素可以由<时间,状态>对表示,标记为 ( z , t ) (z, t) (z,t)

假定我们要学习的动态微分系统的形式为:
d z d t = f ( z ( t ) , t ) (4) \frac{dz}{dt}=f(z(t), t) \tag{4} dtdz=f(z(t),t)(4)
用于学习的观测数据集为 { ( z 0 , t 0 ) , ( z 1 , t 1 ) , . . . , ( z N , t N ) } \{(z_0, t_0), (z_1, t_1), ..., (z_N, t_N)\} {(z0,t0),(z1,t1),...,(zN,tN)}

我们利用可学习的网络模型 f ^ ( z , t , θ ) \hat{f}(z, t, \theta) f^(z,t,θ)来近似动态系统真实的微分函数 f ( z , t ) f(z, t) f(z,t)

若以某状态 z 0 z_0 z0为起始状态(假定 z 0 z_0 z0在数据集中的时间标记为 t 0 t_0 t0),我们验证所学到的网络模型是否很好的近似了真实的微分函数 f ( z , t ) f(z, t) f(z,t)的基本方式为:

  • z 0 z_0 z0作为起始状态,以网络模型作为微分函数,利用ODE Solver求得 t 1 t_1 t1时刻的状态 z ^ 1 \hat{z}_1 z^1
  • 从数据集中找到时间标记为 t 1 t_1 t1的状态 z 1 z_1 z1
  • 利用 z ^ 1 \hat{z}_1 z^1 z 1 z_1 z1的差异来度量学习效果,一种可能的损失函数为 L ( z ^ 1 ) = 1 2 ∣ ∣ z ^ 1 − z 1 ∣ ∣ 2 2 L(\hat{z}_1)=\frac{1}{2}||\hat{z}_1-z_1||_2^2 L(z^1)=21∣∣z^1z122

实际数据不可能只有一个样本,那么损失函数的一般形式采用如下所示的均方误差(MSE):

L = 1 N ∑ i = 1 N ∣ ∣ z ^ i − z i ∣ ∣ 2 2 = 1 N ∑ i = 1 N ∣ ∣ ∫ t i − 1 t i f ^ ( z , t , θ ) d t − z i ∣ ∣ 2 2 = 1 N ∑ i = 1 N ∣ ∣ O D E S o l v e r ( z i − 1 , f , t i − 1 , t i , θ ) − z i ∣ ∣ 2 2 (5) L=\frac{1}{N}\sum_{i=1}^{N}||\hat{z}_i-z_i||_2^2=\frac{1}{N}\sum_{i=1}^{N}||\int_{t_{i-1}}^{t_i}\hat{f}(z, t, \theta)dt-z_i||_2^2=\frac{1}{N}\sum_{i=1}^{N}||ODESolver(z_{i-1}, f, t_{i-1}, t_i, \theta)-z_i||_2^2 \tag{5} L=N1i=1N∣∣z^izi22=N1i=1N∣∣ti1tif^(z,t,θ)dtzi22=N1i=1N∣∣ODESolver(zi1,f,ti1,ti,θ)zi22(5)

我们要更新的模型参数为 θ \theta θ,因此,我们最终需要得到的是 d L d θ \frac{dL}{d\theta} dθdL。但是,从 z ( t i − 1 ) z(t_{i-1}) z(ti1) z ( t i ) z(t_i) z(ti)利用了ODESolver算子,一个python版本的ODESolver算子如下所示:

def ode_solve(z0, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z

当系统是时变的,输入 t t t是作为网络的输入,当系统是时不变的, t t t在此处对函数的输出没有影响。一个时不变的 f f f例子给出如下:

class LinearODEF(ODEF):
    def __init__(self, W):
        super(LinearODEF, self).__init__()
        self.lin = nn.Linear(2, 2, bias=False)
        self.lin.weight = nn.Parameter(W)

    def forward(self, x, t):
        return self.lin(x)

虽然,构建的网络模型基于函数 f f f,但是完整的网络是基 f f fODESolver,并采用一个设定时间步 d t dt dt的累积函数 z ^ n = N e t ( z 0 , θ , t 0 , t 1 ) = z 0 + ∑ i = 1 n f ( z ^ i − 1 , t i − 1 , θ ) d t \hat{\mathbf{z}}_{n}=Net(\mathbf{z}_0, \theta, t_0, t_1)=\mathbf{z}_0+\sum_{i=1}^{n}f(\hat{\mathbf{z}}_{i-1}, t_{i-1}, \theta)dt z^n=Net(z0,θ,t0,t1)=z0+i=1nf(z^i1,ti1,θ)dt,其中, n = t 1 − t 0 d t n=\frac{t_1-t_0}{dt} n=dtt1t0。为了得到 d L d θ \frac{dL}{d\theta} dθdL,我们需要采用与反向传播算法一样的链式法则,同时需要计算 d L d z ^ , d L d t \frac{dL}{d\hat{\mathbf{z}}},\frac{dL}{dt} dz^dL,dtdL

a ( t ) = d L d z ^ ( t ) \mathbf{a}(t)=\frac{dL}{d\hat{\mathbf{z}}(t)} a(t)=dz^(t)dL,我们有:
d a ( t ) d t = − a ( t ) ∂ f ( z ( t ) , t , θ ) ∂ z ( t ) (6) \frac{d\mathbf{a}(t)}{dt}=-\mathbf{a}(t)\frac{\partial f(\mathbf{z}(t),t,\theta)}{\partial \mathbf{z}(t)}\tag{6} dtda(t)=a(t)z(t)f(z(t),t,θ)(6)

式(6)的证明
证明
其中第二行中分子的第二项证明(上面第三行为泰勒展开):
z ( t + ϵ ) = ∫ t t + ϵ f ( z ( t ) , t , θ ) d t + z ( t ) = T ϵ ( z ( t ) , t ) \mathbf{z}(t+\epsilon)=\int_{t}^{t+\epsilon}f(\mathbf{z}(t),t,\theta)dt+\mathbf{z}(t)=T_{\epsilon}(\mathbf{z}(t),t) z(t+ϵ)=tt+ϵf(z(t),t,θ)dt+z(t)=Tϵ(z(t),t)
d L ∂ z ( t ) = d L d z ( t + ϵ ) d t + ϵ d z ( t ) ⇒ a ( t ) = a ( t + ϵ ) ∂ T ϵ ( z ( t ) , t ) ∂ z ( t ) \frac{dL}{\partial \mathbf{z}(t)}=\frac{dL}{d\mathbf{z}(t+\epsilon)}\frac{d\mathbf{t+\epsilon}}{d\mathbf{z}(t)}\Rightarrow \mathbf{a}(t)=\mathbf{a}(t+\epsilon)\frac{\partial T_{\epsilon}(\mathbf{z}(t), t)}{\partial \mathbf{z}(t)} z(t)dL=dz(t+ϵ)dLdz(t)dt+ϵa(t)=a(t+ϵ)z(t)Tϵ(z(t),t)

假如,我们从 { ( z 0 , t 0 ) , ( z 1 , t 1 ) , . . . , ( z N , t N ) } \{(\mathbf{z}_0, t_0), (\mathbf{z}_1, t_1), ..., (\mathbf{z}_N, t_N)\} {(z0,t0),(z1,t1),...,(zN,tN)}中抽取一个样本 ( z N − 1 , t N − 1 ) → ( z N , t N ) (\mathbf{z}_{N-1}, t_{N-1})\rightarrow (\mathbf{z}_{N}, t_N) (zN1,tN1)(zN,tN),表示以 z N − 1 \mathbf{z}_{N-1} zN1为起始状态,经过时间 t N − t N − 1 t_N-t_{N-1} tNtN1,动态系统(此处假设为时不变系统)的状态变为 z N \mathbf{z}_N zN。由内嵌ODESolver算子的神经网络模型得到的状态为 z ^ N \hat{\mathbf{z}}_N z^N

由于 z ^ N \hat{\mathbf{z}}_N z^N相当于一般学习任务的输出 y y y,所以 d L d z ^ ( t N ) = d d z ^ N ( 1 2 ∣ ∣ z ^ N − z N ∣ ∣ 2 2 ) \frac{dL}{d\hat{\mathbf{z}}(t_N)}=\frac{d}{d\hat{\mathbf{z}}_N}(\frac{1}{2}||\hat{\mathbf{z}}_N-\mathbf{z}_N||_2^2) dz^(tN)dL=dz^Nd(21∣∣z^NzN22),直接能得到结果。

重要的是怎么得到 d L z ^ N − 1 \frac{dL}{\hat{\mathbf{z}}_{N-1}} z^N1dL(从 z ^ N − 1 \hat{\mathbf{z}}_{N-1} z^N1 z ^ N \hat{\mathbf{z}}_N z^N的过程是一个多次累积的ODESolver算法),根据公式(6)有:
d L d z ^ N − 1 = a ( t N − 1 ) = a ( t N ) + ∫ t N t N − 1 d a ( t ) d t d t = a ( t N ) − ∫ t N t N − 1 a ( t ) T ∂ f ( z ^ ( t ) , t , θ ) ∂ z ^ ( t ) (7) \frac{dL}{d\hat{\mathbf{z}}_{N-1}}=\mathbf{a}(t_{N-1})=\mathbf{a}(t_{N})+\int_{t_N}^{t_{N-1}}\frac{d\mathbf{a}(t)}{dt}dt=\mathbf{a}(t_{N})-\int_{t_{N}}^{t_{N-1}}\mathbf{a}(t)^T\frac{\partial f(\hat{\mathbf{z}}(t), t, \theta)}{\partial \hat{\mathbf{z}}(t)}\tag{7} dz^N1dL=a(tN1)=a(tN)+tNtN1dtda(t)dt=a(tN)tNtN1a(t)Tz^(t)f(z^(t),t,θ)(7)

我们可以看到,公式(7)用ODESolver算法就将梯度 d L d z ^ N \frac{dL}{d\hat{\mathbf{z}}_N} dz^NdL反向传播给 d L d z ^ N − 1 \frac{dL}{d\hat{\mathbf{z}}_{N-1}} dz^N1dL

如下图所示,数据是按 t 0 → t 1 t_0\rightarrow t_1 t0t1, t 1 → t 2 t_1\rightarrow t_2 t1t2,…, t N − 1 → t N t_{N-1}\rightarrow t_N tN1tN组合作为训练数据的,但是 d L d z ^ ( t N ) \frac{dL}{d\hat{\mathbf{z}}(t_N)} dz^(tN)dL的梯度可以传播给 t N − 1 t_{N-1} tN1,也可以继续往后传播给 t N − 2 t_{N-2} tN2直至数据采集时的起点 t 0 t_0 t0。所有前方的数据都可以把梯度传播给它后方(下图右2为前,左为后)用于训练。
在这里插入图片描述
不要忘记了我们的目的,计算 d L d θ \frac{dL}{d\theta} dθdL
ODESolver算子以 z ^ i \hat{\mathbf{z}}_i z^i作为中间状态,我们根据链式法则有 a θ ( t ) = d L d θ = d L d z ^ d z ^ d θ = a ( t ) d z ^ d θ \mathbf{a}_{\theta}(t)=\frac{dL}{d\theta}=\frac{dL}{d\hat{\mathbf{z}}}\frac{d\hat{\mathbf{z}}}{d\theta}=\mathbf{a}(t)\frac{d\hat{\mathbf{z}}}{d\theta} aθ(t)=dθdL=dz^dLdθdz^=a(t)dθdz^

我们可以先求得 d L d θ = ∫ d d t d L d θ d t \frac{dL}{d\theta}=\int\frac{d}{dt}\frac{dL}{d\theta}dt dθdL=dtddθdLdt中的 d d t d L d θ \frac{d}{dt}\frac{dL}{d\theta} dtddθdL
d d t d L d θ = d d θ d L d z ^ d z ^ d t = d d θ a ( t ) f ( z ^ , t , θ ) = a ( t ) ∂ f ( z ^ , t , θ ) ∂ θ \frac{d}{dt}\frac{dL}{d\theta}=\frac{d}{d\theta}\frac{dL}{d\hat{\mathbf{z}}}\frac{d\hat{\mathbf{z}}}{dt}=\frac{d}{d\theta} \mathbf{a}(t)f(\hat{\mathbf{z}}, t, \theta)=\mathbf{a}(t)\frac{\partial f(\hat{\mathbf{z}}, t, \theta)}{\partial \theta} dtddθdL=dθddz^dLdtdz^=dθda(t)f(z^,t,θ)=a(t)θf(z^,t,θ)
a θ ( t N ) = 0 \mathbf{a}_{\theta}(t_N)=0 aθ(tN)=0,得:

a θ ( t N − 1 ) = d L d θ = − ∫ t N t N − 1 a ( t ) ∂ f ( z ^ , t , θ ) ∂ θ d t (8) \mathbf{a}_{\theta}(t_{N-1})=\frac{dL}{d\theta}=-\int_{t_N}^{t_{N-1}}\mathbf{a}(t)\frac{\partial f(\hat{\mathbf{z}}, t, \theta)}{\partial \theta}dt \tag{8} aθ(tN1)=dθdL=tNtN1a(t)θf(z^,t,θ)dt(8)

当然,还需要求(对于时变系统需要 a t ( t ) = d L d t = d L d z ^ d z ^ d t = a ( t ) f ( z ^ , t , θ ) \mathbf{a}_t(t)=\frac{dL}{dt}=\frac{dL}{d\hat{\mathbf{z}}}\frac{d\hat{\mathbf{z}}}{dt}=\mathbf{a}(t)f(\hat{\mathbf{z}}, t,\theta) at(t)=dtdL=dz^dLdtdz^=a(t)f(z^,t,θ),由于我们想求任意时刻的 a t ( t ) \mathbf{a}_t(t) at(t),而网络模型的forward只输出 t N t_N tN时刻的 f ( z N ^ , t N , θ ) f(\hat{\mathbf{z}_{N}}, t_N,\theta) f(zN^,tN,θ),依据此式,我们只能计算出 a t ( t N ) = a ( t N ) f ( z N ^ , t N , θ ) \mathbf{a}_t(t_N)=\mathbf{a}(t_N)f(\hat{\mathbf{z}_{N}}, t_N,\theta) at(tN)=a(tN)f(zN^,tN,θ)。,因此需要寻找其它计算形式:
我们可以令 d L d t = ∫ d d t d L d t d t , \frac{dL}{dt}=\int \frac{d}{dt}\frac{dL}{dt}dt, dtdL=dtddtdLdt,
d d t d L d t = d d t d L d z ^ d z ^ d t = a ( t ) ∂ f ( z ^ , t , θ ) ∂ t \frac{d}{dt}\frac{dL}{dt}=\frac{d}{dt}\frac{dL}{d\hat{\mathbf{z}}}\frac{d\hat{\mathbf{z}}}{dt}=\mathbf{a}(t)\frac{\partial f(\hat{\mathbf{z}}, t,\theta)}{\partial t} dtddtdL=dtddz^dLdtdz^=a(t)tf(z^,t,θ)
根据上面我们有 a t ( t N ) = a ( t N ) f ( z N ^ , t N , θ ) \mathbf{a}_t(t_N)=\mathbf{a}(t_N)f(\hat{\mathbf{z}_{N}}, t_N,\theta) at(tN)=a(tN)f(zN^,tN,θ)
则可得:
a t ( t N − 1 ) = d L d t N − 1 = a t ( t N ) − ∫ t N t N − 1 a ( t ) ∂ f ( z ^ ( t ) , t , θ ) ∂ t d t (9) \mathbf{a}_t(t_{N-1})=\frac{dL}{dt_{N-1}}=\mathbf{a}_t(t_N)-\int_{t_N}^{t_{N-1}}\mathbf{a}(t)\frac{\partial f(\hat{\mathbf{z}}(t),t,\theta)}{\partial t}dt\tag{9} at(tN1)=dtN1dL=at(tN)tNtN1a(t)tf(z^(t),t,θ)dt(9)

参考


  1. Chen R, Rubanova Y, Bettencourt J and Duvenaud D. Neural Ordinary Differential Equations(PDF). NeurIPS 2018. ↩︎

  2. Neural Ordinary Differential Equations(Jupyter notebook). ↩︎

### 经常微分方程(Neural ODEs)网络概述 经常微分方程(Neural Ordinary Differential Equations, Neural ODEs)是一种新型的深度学习框架,它通过连续动力学建模替代离散化的层结构。传统的经网络由一系列离散的层组成,而Neural ODE则定义了一个输入到输出的映射作为一组ODE的解[^4]。 #### 基本原理 Neural ODE的核心思想是将隐藏状态 \( h(t) \) 的演化描述为时间 \( t \) 上的一个连续函数,该函数由一个可学习的动力系统控制: \[ \frac{dh(t)}{dt} = f(h(t), t, θ), \] 其中 \( f \) 是一个由经网络参数化的时间依赖向量场,\( θ \) 表示网络的参数集合。为了得到最终的状态 \( h(T) \),可以通过数值求解器积分上述ODE获得结果[^5]。 ```python import torch from torchdiffeq import odeint class ODENet(torch.nn.Module): def __init__(self, dim): super(ODENet, self).__init__() self.linear = torch.nn.Linear(dim, dim) def forward(self, t, y): # 定义导数关系 return self.linear(y) def run_ode(): func = ODENet(dim=10) initial_state = torch.randn(10) # 初始条件 timespan = torch.tensor([0., 5.]) # 积分区间 solution = odeint(func, initial_state, timespan) # 数值求解ODE return solution[-1] # 返回t=5时刻的结果 result = run_ode() print(result) ``` 此代码片段展示了如何利用 `torchdiffeq` 库实现简单的Neural ODE模型[^6]。 #### 实现方法的优势与挑战 相比传统经网络,Neural ODE具有以下优势: - **内存效率高**:由于不需要存储每一层中间激活值用于反向传播,因此显著减少了GPU显存占用。 - **灵活调整精度**:可以根据需求自由调节数值积分器的误差容忍度以平衡速度和准确性。 然而也存在一些挑战: - 计算复杂度较高,尤其是对于复杂的矢量场; - 需要精心设计适合特定任务的ODE形式以及初始化策略[^7]。 #### 应用场景举例 1. **时间序列预测** 类似于液态经网络[Liquid State Machine][^8],Neural ODE特别适用于处理具有内在连续变化规律的数据集,比如股票价格波动或者天气预报等。 2. **图像分类** ResNet架构可以被看作是对离散版本的Neural ODE近似的实例;相应地,采用真正的连续型表示可能带来更优性能表现[^9]。 3. **概率分布生成** Normalizing Flows是一类强大的生成对抗网络变体,它们能够借助Neural ODE构建更加平滑且多样化的样本空间转换路径[^10]。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

windSeS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值