SMM在机器学习和深度学习中的应用

状态空间方程

状态空间方程提供了一种将系统的数学模型表达为一组一阶微分方程的方法,描述了系统状态的演变和输出方程。

由两个主要方程组成:

  1. 状态方程:
    描述系统状态随时间的变化。对于线性系统,状态方程可以写为:
    x ˙ ( t ) = A x ( t ) + B u ( t ) \dot{x}(t)=Ax(t)+Bu(t) x˙(t)=Ax(t)+Bu(t)
    其中 x ˙ ( t ) \dot{x}(t) x˙(t)表示状态向量 x x x关于时间的导数, A A A是系统矩阵,描述系统状态之间的作用, B B B是输入矩阵,描述外部输入 u ( t ) u(t) u(t)如何影响系统状态。

  2. 输出方程:
    描述系统输出与系统状态之间的关系,对于线性系统,输出方程可以写为:
    y ( t ) = C x ( t ) + D u ( t ) {y}(t)=C{x}(t)+Du(t) y(t)=Cx(t)+Du(t)
    其中, y ( t ) y(t) y(t)是输出向量, C C C是输出矩阵,描述系统状态如何影响输出,而 D D D是直接传递矩阵,描述输入如何直接影响输出。

状态空间方程在机器学习和深度学习中的应用

循环神经网络(RNNs)

RNN的核心思想与状态空间模型紧密相关,在RNN中,网络的隐藏状态可以被视为动态系统的“状态”,而网络的权重定义了状态如何随时间演化(状态转移)以及如何生成输出(输出方程)。

对于一个基本的RNN单元,给定时刻 t t t的输入 x t x_t xt和上一时刻的隐藏状态 h t − 1 h_{t-1} ht1,隐藏状态的更新可以表示为:
h t = f ( W h h h t − 1 + W x h x t + b h ) h_t=f(W_{hh}h_{t-1}+W_{xh}x_t+b_h) ht=f(Whhht1+Wxhxt+bh)
其中, f f f是激活函数, W h h W_{hh} Whh W x h W_{xh} Wxh分别是隐藏状态和输入到隐藏状态的权重矩阵 b h b_h bh是偏置项。

输出 y t y_t yt于时刻 t t t可以通过对当前隐藏状态 h t h_t ht应用另一组权重 W h y W_{hy} Why和偏置 b y b_y by来计算:
y t = g ( W h y h t + b y ) y_t=g(W_{hy}h_t+b_y) yt=g(Whyht+by)
其中 g g g是输出层的激活函数。

序列模型和时间预测

序列模型中的状态空间方程:

  1. 状态方程:
    x t + 1 = F t X t + G t u t + w t x_{t+1}=F_tX_t+G_tu_t+w_t xt+1=FtXt+Gtut+wt

  2. 观测方程:
    y t = H t x t + v t y_t=H_tx_t+v_t yt=Htxt+vt

SNN中的状态空间方程

τ m d V ( t ) d t = − [ V ( t ) − V r e s t ] + R I ( t ) τ_m\frac{dV(t)}{dt}=-[V(t)-V_{rest}]+RI(t) τmdtdV(t)=[V(t)Vrest]+RI(t)

  • V ( t ) V(t) V(t)是在时间 t t t的膜电位
  • τ m τ_m τm是膜时间常数,决定了电位衰减的速度。
  • V r e s t V_{rest} Vrest是静息电位,即没有输入时神经元的稳态电位。
  • R R R是膜电阻
  • I ( t ) I(t) I(t)是输入电流

当膜电位 V ( t ) V(t) V(t)到达阈值 V t h r e s h V_{thresh} Vthresh时,神经元发放一个脉冲,并将膜电位重置到某个值,通常是 V r e s e t V_{reset} Vreset。这个过程可以用以下条件和重置方程表示:
if  V ( t ) ≥ V thresh  then  V ( t ) ← V reset \text{if } V(t) \geq V_{\text{thresh}} \text{ then } V(t) \leftarrow V_{\text{reset}} if V(t)Vthresh then V(t)Vreset

深度序列模型

针对序列数据的深度学习模型可被视为围绕循环、卷积或注意力等简单机制建立的序列到序列转换。

定义 1.1(非正式)。作者使用序列模型来指代在序列 y = f_θ(x) 上的参数化映射,其中输入和输出 x、y 是 R^D 中长度为 L 的特征向量序列,θ 是通过梯度下降学习的参数。

d239c930d97e2dafb1d7876ab4ae0937

比如RNN,CNN,Transformers都是深度序列模型,深度序列模型面临一系列问题。例如RNN训练速度慢,有梯度消失等问题;CNN专注局部上下文,但是序列推理成本高,上下文长度受限;Transformers序列长度上存在二次拓展问题,是n^2d问题;神经一般微分方程(NODE),理论上可以解决连续时间问题和长期依赖关系,但效率比较低。

所以深度序列模型面临几大挑战:

  1. 通用能力

    RNN:需要快速更新隐藏状态的有状态设置,例如在线处理任务和强化学习;

    CNN:对音频、图像和视频等均匀采样的感知信号进行建模

    Transformers:对语言等领域中密集、复杂的交互进行建模

    NODE:处理非典型世时间序列设置,如缺失或不规则采样数据。

  2. 计算效率

    在训练时,任务一般可以用整个输入序列的损失函数来表述,在推理时,设置可能会发生变化;例如,在在线处理或自回归生成设置中,输入每次只显示一个时间步,模型必须能高效地按顺序处理这些输入。

    RNN本身是序列性的,很难在GPU和TPU等现代硬件加速器上进行训练;另一方面,CNN和Transformers则难以进行高效的自回归推理,因为它们不是有状态的;处理单个新输入的成本可能会与模型的整个上下文大小成比例关系。更奇特的模型可能会带来额外的功能,但通常会使其计算更加困难和缓慢(如需要调用昂贵的微分方程求解器)

  3. 长程依赖

    现实世界中困难可能来自于无法捕捉数据中的交互,比如模型的上下文窗口有限;也可能来自于优化问题,比如梯度消失。

状态空间序列模型(SSM)

SSM定义为一个简单的序列模型,通过一个隐式的潜在状态映射一个1维函数或序列:
20240129160640
20240129160629

SSM 是一种简单而基本的模型,具有许多丰富的特性。它们与 NDE、RNN 和 CNN 等模型族密切相关,实际上可以以多种形式编写,以实现通常需要专门模型才能实现的各种功能(挑战一)。SSM 包含状态方程和观测方程,状态方程是由上一时刻推导,这一步可以与RNN联系;而观测方程描述系统状态如何映射到可观测的数据,这一步类似于CNN结构,有一定的局部捕捉.

  • SSM 是连续的。SSM 本身是一个微分方程。因此,它可以执行连续时间模型的独特应用,如模拟连续过程、处理缺失数据,以及适应不同的采样率。
  • SSM 是循环的。可以使用标准技术将 SSM 离散化为线性 recurrence,并在推理过程中模拟为状态循环模型,每个时间步的内存和计算量保持不变。
  • SSM 是卷积系统。SSM 是线性时不变系统,可显式表示为连续卷积。此外,离散时间版本可以在使用离散卷积进行训练时并行化,从而实现高效训练。

因此,SSM 是一种通用序列模型,在并行和序列环境以及各种领域(如音频、视觉、时间序列)中都能高效运行。

SSM 的通用性也有代价。原始 SSM 仍然面临两个额外挑战 —— 也许比其他模型更严重 —— 这阻碍了它们作为深度序列模型的使用。挑战包括:(1)一般 SSM 比同等大小的 RNN 和 CNN 慢得多;(2)它们在记忆长依赖关系时会很吃力,例如继承了 RNN 的梯度消失问题。

利用结构化 SSM 进行高效计算(S4)

由于状态表示 x ( t ) ∈ R N x (t) ∈ R^N x(t)RN对计算和内存的要求过高(挑战二),通用的 SSM 在实践中无法用作深度序列模型。

对于 SSM 的状态维度 N 和序列长度 L,仅计算完整的潜在状态 x x x 就需要 O ( N 2 L ) O (N^2L) O(N2L)次运算和 O ( N L ) O (NL) O(NL) 的空间 —— 与计算总体输出的 Ω ( L + N ) Ω(L + N) Ω(L+N) 下界相比。因此,对于合理大小的模型(例如 N ≈ 100),SSM 使用的内存要比同等大小的 RNN 或 CNN 多出几个数量级,因此作为通用序列建模解决方案,SSM 在计算上是不切实际的。

S4的前身

Hippo:假设 t 0 t_0 t0时刻我们看到了信号 u ( t ) u(t) u(t)的之前部分:

  1. 我们希望在一个memory budget来压缩前面这一段的input来学习特征,一个很容易想到的方法是用多项式去近似这段input

20240129164121

  1. 在我们接收到更多signal的时候,我们希望仍然在这个memory budget内对整段signal进行压缩,自然,你得更新你的多项式的各项系数,如下图底部所示
    20240129164213

  2. 以上,会涌现出两个问题:

    1. 如何找到这些最优的近似?
    2. 如何快速地更新多项式的参数?

为了解决这两个问题,我们需要一个measure去定义一个近似的好坏程度。例如,可以使用EDM

  1. 这就引出了HiPPO(High-order Polynomial Projection Operator)的正式定义,其为两个信号和两个矩阵的组合:
    20240129164444

  2. HiPPO相当于将函数映射到函数,这里给个通俗的例子解释一下:
    20240129164536
    和上面一样,这里的u是原信号,x是压缩后的信号。给定一个持续增长的u,HiPPO允许online update压缩的x。如果使用一个64unit的polynomial压缩器(完全表示需要10000unit,所以是非常高度的压缩),可以发现EDM很不错,保留了大量之前的信息:
    20240129164640
    其中红色的线相当于对输入的重建(可以看出来,离当下最近时刻的 其刻画最准确,至于离当下最远的时刻 则其刻画的不那么准确 )
    这里要注意,HiPPO只需要看到这个时刻的多项式(polynomial)参数和在此之前的signal u,不需要看到之前的多项式参数.

  3. 上面都是用EDM这个measure的,但是我们在学习过程中用的往往不只一个measure(例如一个time-varying measure can change over time),这个时候如何去建模?
    最终,作者得到了一个结论:HiPPO可以在各种measure上面成立:
    20240129164849

S4的推出

我们正式定义下S4

  1. 首先,有一个state space model,简称为SSM
  2. 其次,在下图所示的两个方程中插入特定的矩阵值
    20240129165007
  3. 学习对应的参数

S4的性质:连续的表示、用Recurrent快速infer、用Convolutional快速训练

20240129165101

  1. 第一个性质是连续的表示,且就算SSM在离散数据上训练,它仍能学习到底层蕴含的连续信息,因为在SSM眼里,sequence不过是连续信号signal的采样(离散形式),或者说连续的信号模型是离散的序列模型的概括
    20240129165139

  2. 第二个性质是有效的online计算,这点之前在HiPPO提到了,就是计算下一时刻的state x ′ x' x 只需要这一时刻的state x x x 和全局输入 u u u
    20240129165225虽然需要全局输入,但是这个全局的计算是常数时间的,这与RNN相同,而与Transformer/CNN不同

    之所以是常数时间,也与RNN相同,因为有state(中间这条蓝线),这导致下一个state的计算只需要上一个state + 全局的输入

  3. SSM的一个问题是,当知道未来的signal的时候,训练是低效的。有没有办法并行化SSM?作者提出了使用一个卷积核 K K K ,绕过状态 x x x ,直接从输入 u u u 到输出 y y y(而非先输入到状态、状态再到输出)
    20240129165725

选择机制的SSM算法(S6)

作者认为,序列建模的一个基础问题是把上下文压缩成更小的状态(We argue that a fundamental problem of sequence modeling is compressing context into a smaller state)

  • 从这个角度来看,注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(也就是KV缓存),直接导致训练和推理消耗算力大

    For example, attention is both effective and inefficient because it explicitly does not compress context at all. This can be seen from the fact that auto regressive inference requires explicitly storing the entire context (i.e. the KV cache), which directly causes the slow linear-time inference and quadratic-time training of Transformers.

    好比,Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢

  • RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制

    On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time training. However, their effectiveness is limited by how well this state has compressed the context.

    好比,RNN每次只参考前面固定的字数,写的快,但容易忘掉更前面的内容

  • Mamba的解决办法是,让模型对信息有选择性处理,可以关注或忽略特定的内容,即使状态大小固定也能压缩上下文

    好比,Mamba每次参考前面所有内容的一个概括,越往后写对前面内容概括得越狠,丢掉细节、保留大意

总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:高效的模型必须有一个小的状态,而有效的模型必须有一个包含来自上下文的所有必要信息的状态,而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的

在其前身结构化状态空间模型S4中,其有4个参数 ( ∆ , A , B , C ) (∆, A, B, C) (,A,B,C)

20240129170425

且S4是LTI系统,不随输入变化,这些参数控制了以下两个阶段:
20240129170511

  • 第一阶段(1a 1b),通常采用固定公式A = 𝑓𝐴(∆, A)和B = 𝑓𝐵(∆, A, B),将“连续参数”(∆,A,B)转化为“离散参数”(A,B),其中(𝑓𝐴, 𝑓𝐵) 称为离散化规则,且可以使用多种规则来实现这一转换,例如下述方程中定义的零阶保持(ZOH)
    20240129170657

  • 第二阶段(2a 2b,和3a 3b),在参数由(∆,A, B, C)变换为(A, B, C)后,模型可以用两种方式计算,即线性递归(2)或全局卷积(3)。

  • 模型使用卷积模式(3)可以进行高效的并行化训练(其中整个输入序列提前看到),并切换到循环模式(2)以高效的自回归推理(其中输入每次只看到一个时间步)

  • 为何可以做高效的并行化呢,因为该模式能够绕过状态计算,并实现仅包含(B, L, D)的卷积核(3a)

20240129171041矩阵都可以由N个数字表示,为了对批量大小为 B B B、长度为 L L L、具有 D D D个通道的输入序列 x x x进行操作,SSM被独立地应用于每个通道

请注意,在这种情况下,每个输入的总隐状态具有 D N DN DN维,在序列长度上计算它需要 O ( B L D N ) O(BLDN) O(BLDN)时间和内存

各个变量含义:

  • Δ,一个标量,类似遗忘门,
    这个量跟RNN里的gating有着深刻的联系,data dependent的 Δ 跟RNN的forget gate的功能类似

  • B B B,起到的作用类似于:进RNN的memory

  • C C C,起到的作用类似于:取RNN的memory
    所以有人说,data dependent的 B / C B/C B/C的功能跟RNN的input/output gate类似

  • A A A,意味着对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因

而在Mamaba中,作者让这些参数 B 、 C 、 Δ B、C、\Delta BCΔ成为输入的函数,让模型能够根据输入内容自适应地调整其行为

20240129171428

  1. 从S4到S6的过程中,可以看出 B B B C C C的大小从原来的 ( D , N ) (D,N) (D,N)变成了 ( B , L , N ) (B,L,N) (B,L,N), Δ \Delta Δ的大小由原来的 D D D变成了 ( B , L , D ) (B,L,D) (B,L,D)
    进一步,咱们通过
    20240129171634
    20240129171644
    20240129171701
    20240129171707
    来逐一将 B , C , Δ B,C,\Delta B,C,Δ数据依赖(data dependent)化
    至于上面的所谓20240129171758代表把维的输入向量 x x x经过一个线性层map到d维
    N N N即SSM的隐藏层维度(hidden dimension),当然 一般设的比较小
    且每个位置的 B 、 C 、 Δ B、C、\Delta BCΔ都不相同(S4时是所有位置共享)

  2. 虽然A没有变成data dependent,但是通过state space model的离散化操作之后, ( A ˉ , B ˉ ) (\bar A,\bar B) (Aˉ,Bˉ)

  3. 会经过outer product变成 ( B , L , N , D ) (B,L,N,D) (B,L,N,D)的data dependent张量,以一种parameter efficient的方式来达到data dependent的目的

神经常微分方程(NODE)

ResNet

神经常微分方程的动机来自于ResNet。ResNet可以抽象表示为一个表示某些非线性函数、权重矩阵、偏置和残差连接的函数。
h t + 1 = f ( h t , θ t ) + h t h_{t+1}=f(h_t,θ_t)+h_t ht+1=f(ht,θt)+ht

将网络层之间的间隔推向一个无限小的值,我们可以将ResNet转换成一个连续的神经网络,这也正是神经ODE所要做的。通过这样做,我们可以将ResNet的离散层与其连续神经网络表示进行比较。我们可以看到,连续神经网络中潜在状态的变化速率由一些非线性函数决定,而这个函数在时间上不会发生变化,这很像ODE的形式。
h t + 1 = f ( h t , θ t ) + h t h_{t+1}=f(h_t,θ_t)+h_t ht+1=f(ht,θt)+ht
→ h t + 1 − h t = f ( h t , θ t ) \rightarrow h_{t+1}-h_t=f(h_t,θ_t) ht+1ht=f(ht,θt)
→ h t + 1 − h t 1 = f ( h t , θ t ) \rightarrow \frac{h_{t+1}-h_t}{1}=f(h_t,θ_t) 1ht+1ht=f(ht,θt)
→ h t + ∇ − h t ∇ ∣ ∇ = 1 = f ( h t , θ t ) \rightarrow \frac {h_{t+\nabla}-h_t}{\nabla}|_{\nabla=1}=f(h_t,θ_t) ht+ht=1=f(ht,θt)
→ lim ⁡ ∇ → 0 h t + ∇ − h t ∇ ∣ ∇ = 1 = f ( h t , θ t ) \rightarrow\lim\limits_{\nabla\rightarrow0}\frac {h_{t+\nabla}-h_t}{\nabla}|_{\nabla=1}=f(h_t,θ_t) 0limht+ht=1=f(ht,θt)
→ d h ( t ) d t = f ( h ( t ) , t , θ ) \rightarrow \frac{dh(t)}{dt}=f(h(t),t,θ) dtdh(t)=f(h(t),t,θ)

ResNet:离散和连续

比较离散和连续的ResNet:

离散:

h t + 1 = f ( h t , θ t ) + h t h_{t+1}=f(h_t,θ_t)+h_t ht+1=f(ht,θt)+ht

  • L L L个离散层
  • 潜在状态以离散方式改变
  • 潜在状态动态由 L L L个函数控制

连续:
d h ( t ) d t = f ( h ( t ) , t , θ ) \frac{dh(t)}{dt}=f(h(t),t,θ) dtdh(t)=f(h(t),t,θ)

  • 无限层
  • 潜在状态以连续方式改变
  • 潜在状态动态由一个函数控制

RNNs

想通的逻辑可以应用于RNN。RNN中,下一个时间步的隐藏层由一个公式确定,该公式涉及一个非线性激活函数(如tanh),当前和前一个隐藏层的线性变换,以及当前输入和偏差向量的线性变换。假设线性变换的前一个隐藏层是恒等函数,激活函数是向量1,我们可以将该公式重写为微分方程。这使我们能够使用神经网络处理时间序列和序列数据。但是,在处理连续时间时,ODE比RNN更好,因为RNN本质上是离散序列处理器。
f09d32b819e9e2aabcd5343d67ff2326

神经ODE

在神经常微分方程中,隐藏层的导数函数被建模为一个神经网络。与传统微分方程不同的是,神经常微分方程的导数是由神经网络建模的。 这意味着对于像图像分类或时间序列分类这样的IVP问题,我们可以通过神经网络和反向传播从数据中学习导数函数。目标是使用神经网络建模导数函数并以成功完成任务的方式进行学习。即:

  • 我们有一个ODE问题。
  • 我们不知道 d h ( t ) d t = f ( h ( t ) , t , θ ) \frac{dh(t)}{dt}=f(h(t),t,θ) dtdh(t)=f(h(t),t,θ),其中 f ( h ( t ) , t , θ ) f(h(t),t,θ) f(h(t),t,θ)是一个神经网络。
  • 我们希望通过神经网络+反向传播从数据中学习 d h ( t ) d t \frac{dh(t)}{dt} dtdh(t)

神经ODE:前向传播

我们现在讨论神经常微分方程的前向传播。输入状态为
h 0 h_0 h0,表示初始时间步。

状态动态或导数函数表示为神经网络,通常是一个具有一到两个隐藏层的前馈神经网络。即:

状态动态: d h ( t ) d t = f ( h ( t ) , t , θ ) \frac{dh(t)}{dt}=f(h(t),t,θ) dtdh(t)=f(h(t),t,θ)

输出状态表示为在时间上积分动态函数使用常微分方程数值解器。求解器的参数包括初始状态、动态函数、动态函数参数、初始时间、终止时间和 ∇ t \nabla_t t ,它是固定步长求解器所必需的。自适应步长求解器不需要 ∇ t \nabla_t t。我们还需要训练或更新神经常微分方程模型的参数。即:
5f8a3f61fcff0af4dfba54a48988b347

ODEnet 在前向传播中有两个核心特点:

  1. 模型的深度:
    由于 ODEnet 是连续的,它没有明确的层级数。在这篇论文中,作者使用 ODE Solver 评估的次数来代表模型的“深度”。

  2. 深度与误差控制的关系
    ODEnet 的“深度”与误差容忍度直接相关。更低的误差容忍度会增加 ODE Solver 的评估次数,从而增加模型的“深度”。这为我们提供了一个工具,使我们能够在准确性和计算成本之间权衡。

20240129205729

论文中的图示阐明了以下点:

  • a 图:随着误差容忍度的降低,前向传播的函数评估数增加。
  • b 图:展示了评估数与计算时间的关系。
  • c 图:显示前向传播的函数评估数约为反向传播评估数的两倍。这证明了 adjoint sensitivity 方法在内存和计算效率上的优势,因为它不需要逐个评估前向传播的每个步骤。
  • d 图:随着训练的进行,函数的评估数逐渐增加,说明模型的复杂度也随之上升。

总之,ODEnet 提供了一种在准确性与计算成本之间权衡的方法,并且具有许多有趣的性质和效率优势。

反向传播

关于反向传播的核心挑战在于如何将梯度传递通过 ODE Solver。 一个直接的方法是让梯度按照前向传播的计算路径返回,但这种做法在内存和数值误差方面都不太理想。因此,作者采取了一种策略,将前向的 ODE Solver 视为一个黑箱,不需要(或很难)传入梯度,而是采用另一种方法“绕过”。

这种“绕过”策略称为 adjoint method。在反向传播时,模型通过一个增广的 ODE Solver 来计算梯度。 这种方法计算和内存效率高,而且能准确控制数值误差。

具体而言,若我们的损失函数为 L(),且它的输入为 ODE Solver 的输出:

20240129210130

我们第一步需要求 L 对 z(t) 的导数,或者说模型损失的变化如何取决于隐藏状态 z(t) 的变化。其中损失函数 L 对 z(t_1) 的导数可以为整个模型的梯度计算提供入口。作者将这一个导数称为 adjoint a(t) = -dL/z(t),它其实就相当于隐藏层的梯度。

在基于链式法则的传统反向传播中,我们需要从后一层对前一层求导以传递梯度。而在连续化的 ODEnet 中,我们需要将前面求出的 a(t) 对连续的 t 进行求导,由于 a(t) 是损失 L 对隐藏状态 z(t) 的导数,这就和传统链式法则中的传播概念基本一致。下式展示了 a(t) 的导数,它能将梯度沿着连续的 t 向前传。

20240129210256

在获取每一个隐藏状态的梯度后,我们可以再求它们对参数的导数,并更新参数。同样在 ODEnet 中,获取隐藏状态的梯度后,再对参数求导并积分后就能得到损失对参数的导数,这里之所以需要求积分是因为「层级」t 是连续的。这一个方程式可以表示为:

20240129210355

综上,具体过程如下:

  1. 找到梯度的入口点:对于损失函数 L() 和它的输入(即 ODE Solver 的输出)z(t),我们首先计算损失 L 对 z(t) 的导数,这为整个梯度计算提供了一个起点。这个导数称为 adjoint a(t),相当于隐藏层的梯度。
  2. 计算时间上的梯度变化:在连续时间模型 ODEnet 中,梯度沿时间连续传播。a(t) 的导数描述了这种梯度如何沿时间变化。
  3. 求参数的导数:在计算出每个隐藏状态的梯度后,我们再对参数求导。因为时间“层级”是连续的,所以需要进行积分操作来得到损失对参数的导数。

网络对比

1. ResNet

  • 主要组件:
    • f:代表卷积层,其中h是上一层的输出特征图,t是当前的卷积层序号。
    • ResNet:定义了整个模型的架构,由T个残差模块构成
  • 伪代码:
     def f(h, t, θ):
      return nnet(h, θ_t)
     
     def resnet(h):
      for t in [1:T]:
          h = h + f(h, t, θ)
      return h
    

2.ODEnet

  • 主要组件:
    • f:定义为神经网络,这里的θ作为一个整体参数,t作为独立的输入参数也被送入网络。这种设计意味着网络层次是连续的。
    • ODEnet:不需要离散层的循环构建,仅需使用ODE solver来找出在t_1时刻的h
  • 伪代码:
    def f(h, t, θ):
      return nnet([h, t], θ)
    
    def ODEnet(h, θ):
      return ODESolver(f, h, t_0, t_1, θ) 
    

为了进一步比较这两种网络,陈天琦及其团队在 MNIST 数据集上进行了实验。他们对比了一个包含6个残差模块的 ResNet 和一个使用 ODE Solver 替代残差模块的 ODEnet。实验结果展示了这两种网络在 MNIST 上的性能、参数量、内存使用情况和计算复杂度.
20240129205510

NODE对ResNet的优点

  1. 内存优化
  2. 自适应计算,通过显式地改变数值积分的精度,可以自由地权衡模型的速度和精度
  3. 连续标准化流
  4. 时间序列中的不规则采样
    20240129211011
  • 48
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值