【机器学习】图神经网络(NRI)模型原理和运动轨迹预测代码实现

1.引言

1.1.NRI研究的意义

在许多领域,如物理学、生物学和体育,我们遇到的系统都是由相互作用的组分构成的,这些组分在个体和整体层面上都产生复杂的动态。建模这些动态是一个重大的挑战,因为往往我们只能获取到个体的轨迹数据,而不知道其背后的相互作用机制或具体的动态模型。

以篮球运动员在球场上的运动为例,运动员的动态显然受到其他运动员的影响。作为观察者,我们能够推断出场上可能发生的各种交互,如防守、掩护等。然而,手动标注这些交互不仅繁琐,而且耗时。因此,一个更有前景的方法是在无监督的条件下学习这些底层的交互模式,这些模式可能在多种不同的任务中都具有通用性。
在这里插入图片描述

1.2.主要内容

本文将介绍一种基于图结构潜在空间的变分自编码器模型——神经关系推理(Neural Relational Inference)模型。这种模型在相关论文中被详细阐述,并配备了代码库以便于实现和实验。

神经关系推理模型旨在解决预测粒子运动轨迹的问题,特别是在存在未知粒子间相互作用的情况下。设想我们有一组粒子(例如带电粒子),它们因某种相互作用(如电磁力)而在空间中移动。我们观察到每个粒子在一段时间T内的运动轨迹,包括其位置和速度。每个粒子的新状态不仅由其当前状态决定,还受到其他粒子的影响。我们拥有一组粒子的轨迹数据,但粒子间的确切相互作用未知。

模型的目标是通过学习粒子的动态行为,基于已知的轨迹样本来预测未来的轨迹。如果已知粒子间的相互作用形式(即它们如何以图的形式相互连接),预测粒子的动态将更为直接。在这种理想情况下,每个粒子对应于图中的一个节点,而节点间的连接强度可以通过边的权重来表示。然而,在这个问题中,我们并没有获得这样的交互图。

因此,神经关系推理模型采用了变分自编码器的编码器部分,以从给定的轨迹数据中采样潜在的交互图。具体来说,编码器部分使用图神经网络(GNN)技术来捕捉粒子间的潜在关系,并生成一个能够代表这些关系的图结构。这个图结构随后被用作解码器部分的输入,以预测粒子的未来轨迹。

通过这种方式,神经关系推理模型能够同时学习粒子间的潜在交互和粒子的动态行为,从而实现更准确的轨迹预测。在训练过程中,模型通过最大化给定轨迹数据下的似然函数(即证据下界ELBO)来优化其参数,以使得生成的潜在交互图能够最好地解释和预测观察到的轨迹。

2.神经关系推理模型(NRI)原理

2.1.NRI的基本原理

神经关系推理(NRI)模型是一个专注于从观察到的轨迹中推断对象间相互作用和动态行为的模型。它由两个核心组件组成:编码器和解码器,这两个组件是联合训练的。

2.1.1.编码器

编码器负责根据观察到的轨迹数据 x x x 来预测对象间的相互作用,即潜在的图结构 z z z。这里的轨迹数据 x x x 包括 N 个对象在 T 个时间步上的特征向量集合,具体地,我们用 x i t x^t_i xit 表示第 t 个时间点上对象 v i v_i vi 的特征向量(如位置和速度)。所有对象在时间点 t 的特征集合记作 x t = ( x 1 t , … , x N t ) x^t = (x^t_1, \ldots, x^t_N) xt=(x1t,,xNt),而对象 v i v_i vi 的完整轨迹是 x i = ( x i 1 , … , x i T ) x_i = (x^1_i, \ldots, x^T_i) xi=(xi1,,xiT)

编码器 q ( z ∣ x ) q(z|x) q(zx) 的目标是输出一个分布,该分布描述了给定轨迹 x x x 下潜在图结构 z z z 的可能性。特别是, z i j z_{ij} zij 表示对象 v i v_i vi v j v_j vj 之间的离散边类型,用于表示它们之间的交互类型。编码器使用 K 种可能的交互类型对 z i j z_{ij} zij 进行一位有效编码。

2.1.2.解码器

解码器则基于编码器输出的潜在图结构 z z z 和已知的轨迹数据 x x x 来学习并预测对象的动态行为。具体来说,解码器通过以下公式定义:
p ( x ∣ z ) = ∏ t = 1 T p ( x t + 1 ∣ x t , x 1 , z ) p(x|z) = \prod_{t=1}^{T} p(x_{t+1}|x^t, x^1, z) p(xz)=t=1Tp(xt+1xt,x1,z)
它使用图神经网络(GNN)来模拟给定潜在图 z z z 和历史轨迹 x t x^t xt 下,下一个时间步 t + 1 t+1 t+1 的轨迹 x t + 1 x_{t+1} xt+1 的分布。

2.1.3.模型优化

整个模型基于变分自编码器(VAE)框架进行优化,目标是最大化证据下界(ELBO):
L = E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − KL [ q ( z ∣ x ) ∣ ∣ p ( z ) ] L = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}[q(z|x) || p(z)] L=Eq(zx)[logp(xz)]KL[q(zx)p(z)]
其中,KL 表示 Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。先验 p ( z ) p(z) p(z) 假设边类型是均匀分布的,但也可以根据需要采用其他先验分布,比如鼓励稀疏图的先验。

NRI 模型通过编码器和解码器的联合训练,能够无监督地从观察到的轨迹中学习对象间的相互作用和动态行为,这对于理解复杂系统的运动规律和交互模式具有重要意义。

2.2.与VAE编码器的差异

与原始的变分自编码器(VAE)模型相比,我们的神经关系推理(NRI)模型确实存在几个显著的不同之处。以下是对这些差异的详细改写和描述:

  1. 多时间步预测
    在原始的VAE中,解码器通常被训练来根据潜在变量(z)重构单个数据点。然而,在我们的NRI模型中,为了捕捉系统动态的连续性和交互的长期影响,我们训练解码器来预测多个时间步的轨迹,而不仅仅是单个时间步。这种设置使得解码器在预测过程中能够充分利用潜在交互图(z)中的信息,从而避免了解码器忽略潜在变量(z)的问题。

  2. 离散潜在变量与连续松弛
    原始的VAE通常处理连续的潜在变量,而我们的NRI模型则使用离散的潜在变量(z)来表示对象之间的交互类型。为了能够在反向传播过程中优化离散的潜在变量,我们采用了连续的松弛方法,如Gumbel-Softmax或Straight-Through(ST)估计器,以便能够利用重参数化技巧进行梯度传播。这种方法允许我们在保持潜在变量离散性的同时,有效地优化模型参数。

  3. 未建模的初始状态
    在原始的VAE中,通常会对整个数据序列(包括初始状态)进行建模。然而,在我们的NRI模型中,我们主要关注于对象之间的动态交互和这些交互如何影响对象的轨迹。因此,我们没有对初始状态的概率(p(x^1))进行显式建模。尽管如此,如果需要,我们可以轻松地扩展模型以包含对初始状态的建模,但这通常不会显著影响模型在动态和交互预测方面的性能。

  4. 图神经网络(GNN)的引入
    除了上述差异外,我们的NRI模型还引入了图神经网络(GNN)来捕捉对象之间的交互。GNN能够处理图结构的数据,并通过在节点和边之间传递信息来更新节点的表示。在我们的模型中,GNN被用作解码器的一部分,它根据潜在交互图(z)和历史轨迹来预测未来的轨迹。这种图结构的数据处理方法使得我们的模型能够更好地捕捉对象之间的复杂交互和依赖关系。

NRI模型通过引入多时间步预测、离散潜在变量与连续松弛、以及图神经网络等方法,在保持原始VAE框架灵活性的同时,针对动态系统和交互推理问题进行了有效的改进和优化。

模型的概览图如图 1 所示。接下来,我们将详细介绍模型的编码器和解码器部分。
在这里插入图片描述图 1. NRI模型由两个共同训练的部分构成:一个编码器,它根据输入轨迹预测潜在交互的概率分布 q ( z ∣ x ) q(z|x) q(zx);以及一个解码器,它根据编码器的潜在编码和轨迹的前一时间步生成轨迹预测。编码器采用具有多轮节点到边(v → e)和边到节点(e → v)消息传递的GNN形式,而解码器则并行运行多个GNN,每个GNN对应编码器潜在编码 q ( z ∣ x ) q(z|x) q(zx)提供的一种边类型。(图引用自论文:Neural Relational Inference for Interacting Systems

2.3.编码器

编码器在NRI模型中的核心任务是,在观察到轨迹数据 x = ( x 1 , … , x T ) x = (x_1, \ldots, x_T) x=(x1,,xT) 的基础上,推断出对象间潜在的成对交互类型 z i j z_{ij} zij。由于真实世界的图结构通常是未知的,我们利用一个在全连接图(即每对对象之间都存在潜在的边)上运作的图神经网络(GNN)来预测这种潜在的图结构。

具体地,我们构建编码器模型如下:

q ( z i j ∣ x ) = softmax ( f enc ( x ) i j 1 : K ) q(z_{ij}|x) = \text{softmax}(f_{\text{enc}}(x)_{ij}^{1:K}) q(zijx)=softmax(fenc(x)ij1:K)

其中, f enc ( x ) f_{\text{enc}}(x) fenc(x) 是我们的编码器函数,它在一个不包含自环的全连接图上应用GNN操作。给定输入轨迹 x 1 , … , x T x_1, \ldots, x_T x1,,xT,编码器执行以下消息传递操作来逐步构建对象的表示和边的嵌入:

  1. 初始化节点嵌入:
    h j 1 = f emb ( x j ) h^1_j = f_{\text{emb}}(x_j) hj1=femb(xj)
    这里, f emb f_{\text{emb}} femb 是一个嵌入函数,它将原始轨迹数据 ( x_j ) 映射到初始的节点表示 h j 1 h^1_j hj1

  2. 边嵌入的第一层更新:
    h ( i j ) 1 = f e 1 ( [ h i 1 , h j 1 ] ) h^1_{(ij)} = f^1_e([h^1_i, h^1_j]) h(ij)1=fe1([hi1,hj1])
    对于每对节点 ( i , j ) (i, j) (i,j) f e 1 f^1_e fe1 是一个边更新函数,它接受两个相邻节点的嵌入 h i 1 h^1_i hi1 h j 1 h^1_j hj1,并输出一个更新的边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1

  3. 节点嵌入的第二层更新:
    h j 2 = f v 1 ( ∑ i ≠ j h ( i j ) 1 ) h^2_j = f^1_v\left(\sum_{i \neq j} h^1_{(ij)}\right) hj2=fv1i=jh(ij)1
    这里, f v 1 f^1_v fv1 是一个节点更新函数,它聚合所有指向节点 j j j 的边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1,并据此更新节点 j j j 的嵌入 h j 2 h^2_j hj2

  4. 边嵌入的第二层更新(可选):
    h ( i j ) 2 = f e 2 ( [ h i 2 , h j 2 ] ) h^2_{(ij)} = f^2_e([h^2_i, h^2_j]) h(ij)2=fe2([hi2,hj2])
    在某些情况下,为了进一步增强边嵌入的表达能力,可以执行第二次边嵌入更新。

最终,我们使用更新后的边嵌入 h ( i j ) 2 h^2_{(ij)} h(ij)2(或 h ( i j ) 1 h^1_{(ij)} h(ij)1 如果不进行第二次更新)来建模边类型的后验概率分布:

q ( z i j ∣ x ) = softmax ( h ( i j ) 2 ) q(z_{ij}|x) = \text{softmax}(h^2_{(ij)}) q(zijx)=softmax(h(ij)2)

这里, θ \theta θ 代表了方程中涉及的所有神经网络参数。通过采用多次消息传递(在我们的例子中为两次),模型能够“分离”多个交互,即使它仅依赖于二元项(即仅考虑两个对象之间的交互)。在单次传递中,边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1 主要基于 x i x_i xi x j x_j xj 的信息,而在第二次传递中, h j 2 h^2_j hj2 则利用了整个图的信息。

函数 f f f 是神经网络,用于在节点和边的表示之间进行映射。在我们的实验中,我们使用了多种神经网络架构,如全连接网络(MLP)或带有注意力池化的1D卷积网络(CNN),作为 f f f 函数的实现。这些选择允许我们根据任务和数据特性灵活地调整模型结构。

值得注意的是,与传统的GNN不同,在我们的NRI模型中,边嵌入 h ( i j ) l h^l_{(ij)} h(ij)l 不再仅仅被视为计算过程中的一个瞬时部分,而是模型的一个关键组成部分,它直接用于执行边分类(即预测对象间的交互类型)。

2.4.采样理论

在NRI模型中,从编码器输出的离散分布 q ( z i j ∣ x ) q(z_{ij}|x) q(zijx) 中采样交互类型 z i j z_{ij} zij 是必要的步骤。然而,由于这些潜在变量是离散的,我们不能直接使用基于梯度的优化方法(如反向传播)来通过采样过程进行训练,因为采样操作本身是不可导的。

为了解决这个问题,我们采用了具体分布(也称为Gumbel-Softmax分布,由Maddison等人,2017和Jang等人,2017提出)的方法,它允许我们从离散分布的连续近似中采样,并使用重参数化技巧来估计梯度。

具体地,我们采用Gumbel-Softmax技巧来从 q ( z i j ∣ x ) q(z_{ij}|x) q(zijx) 中抽取样本 z i j z_{ij} zij,这通过以下方式实现:

z i j = one_hot ( argmax k ( [ h ( i j ) 2 + g ] k ) ) z_{ij} = \text{one\_hot}(\text{argmax}_k([h^2_{(ij)} + g]_{k})) zij=one_hot(argmaxk([h(ij)2+g]k))

但是,由于直接使用 argmax 仍然不可导,我们引入了一个温度参数 τ \tau τ(softmax温度)来“软化”这个选择过程,得到一个可导的近似。于是,我们实际上使用:

y i j = exp ⁡ ( ( h ( i j ) 2 + g ) / τ ) ∑ k exp ⁡ ( ( h ( i j ) k 2 + g k ) / τ ) y_{ij} = \frac{\exp((h^2_{(ij)} + g)/\tau)}{\sum_k \exp((h^2_{(ij)k} + g_k)/\tau)} yij=kexp((h(ij)k2+gk)/τ)exp((h(ij)2+g)/τ)

其中 g g g 是从Gumbel(0,1)分布中独立抽取的与 h ( i j ) 2 h^2_{(ij)} h(ij)2 形状相同的噪声向量。向量 y i j y_{ij} yij 现在是 h ( i j ) 2 h^2_{(ij)} h(ij)2 的一个连续且可导的近似,并且随着 τ \tau τ 趋近于0,它将趋近于一个one-hot编码的样本,即 z i j z_{ij} zij 的一个有效样本。

在训练过程中,我们使用 y i j y_{ij} yij 作为 z i j z_{ij} zij 的近似,并允许梯度通过 y i j y_{ij} yij 反向传播到编码器的参数中。这样,我们就能在保持模型可导性的同时,模拟从离散分布中采样的过程。在测试或部署模型时,我们通常将 τ \tau τ 设置为一个非常小的值(接近于0),以得到接近真实离散样本的结果。

2.5.解码器

解码器的核心功能是预测基于当前交互系统动态的未来状态延续。由于解码器的预测依赖于潜在的图结构 z z z,我们可以利用图神经网络(GNN)来模拟这种依赖关系。

对于物理模拟任务,特别是当状态由位置和速度等物理量组成时,如果 z z z 代表了真实的图结构,那么系统动态通常是马尔可夫的,即 p ( x t + 1 ∣ x t , x 1 , z ) = p ( x t + 1 ∣ x t , z ) p(x_{t+1}|x^t, x^1, z) = p(x_{t+1}|x^t, z) p(xt+1xt,x1,z)=p(xt+1xt,z)。基于这个假设,我们设计了一个类似于交互网络的GNN作为解码器。与一般的交互网络不同,我们的解码器为每种边类型都配备了单独的神经网络。

更具体地说,解码器的工作流程如下:

  1. 边嵌入计算:对于每对节点 i i i j j j,我们根据边类型 z z z 和当前状态 x t x^t xt 计算边嵌入 h ( i j ) t h^t_{(ij)} h(ij)t。这通过以下方式实现:

    h ( i j ) t = ∑ k z i j k f e k ( [ x i t , x j t ] ) h^t_{(ij)} = \sum_{k} z_{ijk} f^k_e([x^t_i, x^t_j]) h(ij)t=kzijkfek([xit,xjt])

    其中, z i j k z_{ijk} zijk 表示边类型向量 z i j z_{ij} zij 的第 k k k 个元素, f e k f^k_e fek 是对应于边类型 k k k 的神经网络。当 z i j k z_{ijk} zijk 是离散的一位有效样本时, h ( i j ) t h^t_{(ij)} h(ij)t 仅由对应的 f e k ( [ x i t , x j t ] ) f^k_e([x^t_i, x^t_j]) fek([xit,xjt]) 确定;而对于连续松弛的情况,我们得到的是所有可能边类型的加权和。

  2. 节点更新:然后,我们使用聚合的边嵌入来更新每个节点的状态。这通过以下公式完成:

    x t + 1 j = x j t + f v ( ∑ i ≠ j h ( i j ) t ) x_{t+1}^j = x^t_j + f_v(\sum_{i \neq j} h^t_{(ij)}) xt+1j=xjt+fv(i=jh(ij)t)

    其中, f v f_v fv 是一个节点更新函数,它根据所有指向节点 j j j 的边嵌入来计算节点 j j j 在下一个时间步的状态变化。注意,我们在更新时添加了当前状态 x j t x^t_j xjt,这样模型实际上学习的是状态变化 Δ x j t \Delta x^t_j Δxjt

  3. 输出分布:最后,我们假设下一个时间步的状态 x t + 1 j x_{t+1}^j xt+1j 服从以 μ t + 1 j \mu_{t+1}^j μt+1j 为均值、 σ 2 I \sigma^2 I σ2I 为协方差矩阵的正态分布,即

    p ( x t + 1 j ∣ x t , z ) = N ( μ t + 1 j , σ 2 I ) p(x_{t+1}^j|x^t, z) = \mathcal{N}(\mu_{t+1}^j, \sigma^2 I) p(xt+1jxt,z)=N(μt+1j,σ2I)

    其中, μ t + 1 j \mu_{t+1}^j μt+1j 通常可以设置为 x t + 1 j x_{t+1}^j xt+1j(或经过某些变换的 x t + 1 j x_{t+1}^j xt+1j), σ 2 \sigma^2 σ2 是一个固定的方差项。这样的假设允许我们利用概率模型来捕捉状态预测的不确定性。

2.6.解码器退化问题优化

在优化基于证据下界(ELBO)的目标函数时,特别是重构损失项 ∑ t = 1 T log ⁡ [ p ( x t ∣ x t − 1 , z ) ] \sum_{t=1}^T \log[p(x_t | x_{t-1}, z)] t=1Tlog[p(xtxt1,z)],我们面临的一个挑战是交互可能对短期动态的影响很小。这可能导致解码器在优化过程中忽略潜在的边信息,仅仅实现一个基于短期预测的次优模型。

为了解决这个问题,我们采用了两种策略。首先,我们预测未来多步的状态,因为在一个较长的预测序列中,一个“退化”的解码器(即忽略潜在边信息的解码器)的表现会显著下降。其次,我们为每种边类型设计了一个单独的多层感知机(MLP),而不是使用一个统一的神经网络来处理所有边类型。这种设计使得边类型对解码器的影响更加明确,从而更难被模型忽略。

具体来说,我们采用了一个滚动预测的策略来预测未来多步的状态。在这个过程中,我们首先用初始的真实状态 x j 1 x^1_j xj1 来启动预测,并计算预测状态的均值 μ j t \mu^t_j μjt。然后,我们将这些预测均值作为下一时间步的输入,迭代地进行多步预测。在达到某个预定的步数(如 M = 10 M=10 M=10)后,我们再次使用真实的状态 x j M + 1 x^{M+1}_j xjM+1 来纠正预测,并继续进行多步预测。这个过程可以表示为:

μ j 2 = f dec ( x j 1 ) \mu^2_j = f_{\text{dec}}(x^1_j) μj2=fdec(xj1)
μ t + 1 j = f dec ( μ j t ) for  t = 2 … M \mu_{t+1}^j = f_{\text{dec}}(\mu^t_j) \quad \text{for } t = 2 \ldots M μt+1j=fdec(μjt)for t=2M
μ M + 2 j = f dec ( x j M + 1 ) \mu_{M+2}^j = f_{\text{dec}}(x^{M+1}_j) μM+2j=fdec(xjM+1)
μ t + 1 j = f dec ( μ j t ) for  t = M + 2 … 2 M \mu_{t+1}^j = f_{\text{dec}}(\mu^t_j) \quad \text{for } t = M+2 \ldots 2M μt+1j=fdec(μjt)for t=M+22M

在这个多步预测的过程中,我们通过反向传播来优化整个序列的预测误差。由于错误会在多步预测中累积,因此一个忽略潜在边信息的退化解码器将会表现得非常糟糕,从而促使模型更好地利用这些边信息来改进预测。这种策略有效地避免了退化解码器的出现,并提高了模型的整体性能。

2.7.递归解码器

在处理那些不满足马尔可夫假设的应用时,我们需要一个更复杂的解码器模型来捕捉时间上的依赖关系。为此,我们引入了一个递归解码器,它使用门控循环单元(GRU)(Cho 等人,2014)来建模 p ( x t + 1 ∣ x t , x 1 , z ) p(x_{t+1}|x^t, x^1, z) p(xt+1xt,x1,z)。递归解码器在GNN的消息传递操作中集成了GRU单元,以捕捉时间上的动态信息。

具体地说,递归解码器的工作流程如下:

  1. 边嵌入计算:对于每对节点 i i i j j j,我们根据边类型 z z z 和前一时间步的隐藏状态 h t h^t ht 计算边嵌入 h ( i j ) t h^t_{(ij)} h(ij)t

    h ( i j ) t = ∑ k z i j k f e k ( [ h i t , h j t ] ) h^t_{(ij)} = \sum_{k} z_{ijk} f^k_e([h^t_i, h^t_j]) h(ij)t=kzijkfek([hit,hjt])

    其中, z i j k z_{ijk} zijk 是边类型向量 z i j z_{ij} zij 的第 k k k 个元素, f e k f^k_e fek 是对应于边类型 k k k 的神经网络。

  2. 节点消息聚合:然后,我们聚合所有指向节点 j j j 的边嵌入,形成消息 MSG j t \text{MSG}^t_j MSGjt

    MSG j t = ∑ i ≠ j h ( i j ) t \text{MSG}^t_j = \sum_{i \neq j} h^t_{(ij)} MSGjt=i=jh(ij)t

  3. GRU更新:我们使用GRU单元来更新节点 j j j 的隐藏状态。GRU的输入是聚合消息 MSG j t \text{MSG}^t_j MSGjt、当前输入 x j t x^t_j xjt 和前一隐藏状态 h j t h^t_j hjt 的组合:

    h t + 1 j = GRU ( [ MSG j t , x j t , h j t ] ) h_{t+1}^j = \text{GRU}([\text{MSG}^t_j, x^t_j, h^t_j]) ht+1j=GRU([MSGjt,xjt,hjt])

  4. 输出预测:我们通过一个输出转换函数 f out f_{\text{out}} fout(由一个小的多层感知机MLP建模)来从新的隐藏状态 h t + 1 j h_{t+1}^j ht+1j 预测下一个时间步的状态 x t + 1 j x_{t+1}^j xt+1j 的均值 μ t + 1 j \mu_{t+1}^j μt+1j

    μ t + 1 j = x j t + f out ( h t + 1 j ) \mu_{t+1}^j = x^t_j + f_{\text{out}}(h_{t+1}^j) μt+1j=xjt+fout(ht+1j)

  5. 输出分布:最后,我们假设下一个时间步的状态 x t + 1 j x_{t+1}^j xt+1j 服从以 μ t + 1 j \mu_{t+1}^j μt+1j 为均值、 σ 2 I \sigma^2 I σ2I 为协方差矩阵的正态分布:

    p ( x t + 1 j ∣ x t , z ) = N ( μ t + 1 j , σ 2 I ) p(x_{t+1}^j | x^t, z) = \mathcal{N}(\mu_{t+1}^j, \sigma^2 I) p(xt+1jxt,z)=N(μt+1j,σ2I)

在预测多步时间时,为了避免在预测路径上产生人为现象,我们采用了一个混合策略。在前 T − M T - M TM 步中,我们使用真实的输入 x j t x^t_j xjt 作为输入。然后,在最后 M M M 步中,我们使用我们预测的均值 μ j t \mu^t_j μjt 作为输入。这种策略确保了我们的预测在开始时基于真实数据,然后在最后几步中逐渐依赖于模型自身的预测,以评估模型对长期动态的建模能力。

2.8.训练模型

在介绍了模型的所有组件之后,我们现在详细阐述训练过程。给定训练样本 x x x,我们遵循以下步骤进行训练:

  1. 编码:首先,我们使用编码器处理输入数据 x x x,并计算潜在变量 z i j z_{ij} zij 的后验分布 q ( z i j ∣ x ) q(z_{ij}|x) q(zijx)

  2. 采样:接着,我们从 q ( z i j ∣ x ) q(z_{ij}|x) q(zijx) 的具体可重参数化近似中采样潜在变量 z i j z_{ij} zij

  3. 解码:然后,我们使用解码器,以采样得到的 z i j z_{ij} zij 和原始输入 x x x 作为条件,计算输出状态 x t x_t xt 的预测均值 μ t \mu_t μt

  4. 计算ELBO:接下来,我们计算证据下界(ELBO)目标函数,它由两部分组成:重建误差和KL散度。

    • 重建误差:我们使用均方误差(MSE)作为重建误差的度量,因为它对应于高斯似然的对数概率的一个近似(在假设高斯分布的方差为常数时)。具体地,我们计算预测均值 μ t j \mu_t^j μtj 和真实值 x t j x_t^j xtj 之间的欧氏距离的平方和,并可能加上一个常数项以匹配方程(3)中的对数概率形式:

      ∑ t = 2 T ∥ x t j − μ t j ∥ 2 2 + const ( 18 ) \sum_{t=2}^{T} \left\| x_t^j - \mu_t^j \right\|_2^2 + \text{const} \quad (18) t=2Txtjμtj22+const(18)

    • KL散度:对于KL散度项,如果先验分布 p ( z ) p(z) p(z) 是均匀的,那么KL散度可以简化为后验分布 q ( z i j ∣ x ) q(z_{ij}|x) q(zijx) 的熵的总和(加上一个常数项)。这是因为均匀分布的KL散度是熵的线性变换。具体地,我们有:

      ∑ i < j H ( q ( z i j ∣ x ) ) + const ( 19 ) \sum_{i<j} H(q(z_{ij}|x)) + \text{const} \quad (19) i<jH(q(zijx))+const(19)

  5. 优化:由于我们使用了可重参数化技巧(如重参数化技巧用于高斯分布),我们可以将ELBO目标函数中的随机变量替换为其可重参数化的形式,从而通过反向传播计算梯度并优化模型参数。

  6. 迭代:我们重复上述步骤,通过迭代优化模型参数,直到模型在验证集上达到满意的性能或达到预设的训练轮次。

在训练过程中,我们可能需要调整一些超参数,如学习率、批量大小、KL散度的权重等,以平衡重建误差和KL散度项,并防止过拟合或欠拟合。

3.NRI模型实验

3.1.设置

这段代码是一个Python脚本,用于下载数据集、解压数据集、导入必要的库,并准备进行数据分析或机器学习任务。以下是详细的中文注释:

# 下载数据集并保存为本地文件 'nri_springs.zip'
!wget -O 'nri_springs.zip' https://surfdrive.surf.nl/files/index.php/s/LXV9iJjxfu5jhdD/download
# 解压 'nri_springs.zip' 到 'data' 目录
!unzip -u nri_springs.zip -d data
# 导入时间处理模块
import time
# 导入命令行参数解析模块
import argparse
# 导入数据序列化模块
import pickle
# 导入文件路径操作模块
import os
# 导入日期时间处理模块
import datetime
# 导入PyTorch深度学习框架
import torch
# 导入PyTorch神经网络模块
import torch.nn as nn
# 导入PyTorch神经网络函数库
import torch.nn.functional as F
# 导入NumPy科学计算库
import numpy as np
# 注释掉的导入数学模块
#import math
# 从PyTorch工具包中导入TensorDataset数据集类
from torch.utils.data.dataset import TensorDataset
# 导入PyTorch数据加载器模块
from torch.utils.data import DataLoader
# 导入PyTorch自动微分Variable类
from torch.autograd import Variable
# 导入PyTorch优化器模块
import torch.optim as optim
# 导入PyTorch学习率调度器模块
from torch.optim import lr_scheduler
# 导入NetworkX图计算库
import networkx
# 导入Matplotlib绘图库
import matplotlib.pyplot as plt
# 尝试导入NetworkX库,如果失败则自动安装
try:
    import networkx
except ModuleNotFoundError:
    !pip install --quiet networkx
    import networkx
# 设置Matplotlib绘图在Jupyter笔记本中内联显示
%matplotlib inline

3.2.数据预处理

加载数据、将数组转换为PyTorch张量,并返回数据加载器对象。

# 定义加载数据的函数,指定批量大小和数据文件后缀
def load_data(batch_size=1, suffix=''):
    # 使用numpy加载训练、验证和测试数据集的位置、速度和边信息
    loc_train = np.load('data/loc_train' + suffix + '.npy')
    vel_train = np.load('data/vel_train' + suffix + '.npy')
    edges_train = np.load('data/edges_train' + suffix + '.npy')
    # ...(类似地加载验证和测试数据集)

    # 获取粒子数量
    num_atoms = loc_train.shape[3]

    # 计算位置和速度数据的最大值和最小值,用于归一化
    loc_max = loc_train.max()
    loc_min = loc_train.min()
    vel_max = vel_train.max()
    vel_min = vel_train.min()

    # 将位置和速度数据归一化到[-1, 1]区间
    loc_train = (loc_train - loc_min) * 2 / (loc_max - loc_min) - 1
    # ...(类似地归一化速度数据)

    # 重新排列数据形状,连接位置和速度数据,并将边信息重塑为一维数组
    loc_train = np.transpose(loc_train, [0, 3, 1, 2])
    vel_train = np.transpose(vel_train, [0, 3, 1, 2])
    feat_train = np.concatenate([loc_train, vel_train], axis=3)
    # ...(类似地处理验证和测试数据集)

    # 将特征和边信息转换为PyTorch张量
    feat_train = torch.FloatTensor(feat_train)
    edges_train = torch.LongTensor(edges_train)
    # ...(类似地转换验证和测试数据集的张量)

    # 排除自环(即节点和自己连接的边)
    off_diag_idx = np.ravel_multi_index(
        np.where(np.ones((num_atoms, num_atoms)) - np.eye(num_atoms)),
        [num_atoms, num_atoms])
    edges_train = edges_train[:, off_diag_idx]
    # ...(类似地处理验证和测试数据集)

    # 创建TensorDataset数据集和DataLoader加载器
    train_data = TensorDataset(feat_train, edges_train)
    train_data_loader = DataLoader(train_data, batch_size=batch_size)
    # ...(类似地创建验证和测试数据集的加载器)

    # 返回训练、验证和测试数据集的加载器,以及位置和速度的最大和最小值
    return train_data_loader, valid_data_loader, test_data_loader, loc_max, loc_min, vel_max, vel_min

代码中定义了一个load_data函数,它使用指定的批量大小和数据文件后缀来加载数据集。加载的数据包括位置、速度和边信息,并将它们转换为PyTorch张量。数据被归一化到[-1, 1]区间,并排除了自环。然后,创建了TensorDatasetDataLoader对象,以便在模型训练中使用批量加载数据。

接下来是使用load_data函数来指定批量大小为128,并加载特定后缀的数据:

# 指定批量大小和数据文件后缀,调用load_data函数加载数据
train_loader, valid_loader, test_loader, _, _, _, _ = load_data(128, "_springs5")

然后,代码中展示了如何从加载器中获取第一个小批量数据,并打印出数据的形状:

# 从训练数据加载器中获取一个迭代器,并取出第一个小批量数据
(x_sample, rel_sample) = next(iter(train_loader))
print(x_sample.shape)  # 打印位置和速度数据的形状
print(rel_sample.shape)  # 打印交互图边信息的形状

最后,代码中展示了如何查看交互图,将边信息列表转换为邻接矩阵,并绘制交互图:

# 打印特定粒子的交互信息
idx = 0
print(rel_sample[idx])

# 定义函数将边信息列表转换为邻接矩阵
def list_to_adj(rel):
    b = torch.zeros((5*5))
    for i in range(4):
        b[i*5+i+1:(i+1)*5+(i+1)] = rel[i*5:(i+1)*5]
    return b.reshape((5,5))

# 转换交互信息为邻接矩阵并打印
b = list_to_adj(rel_sample[idx])
print(b.reshape((5,5)))

# 定义函数绘制图
def show_graph(b):
    g = b.reshape((5,5)).cpu().numpy()
    graph = networkx.from_numpy_array(g)
    networkx.draw(graph, with_labels=True)
    plt.show()

# 绘制交互图
show_graph(b)

代码展示了如何分析和可视化粒子间的交互关系。通过将边信息转换为邻接矩阵,可以更直观地理解粒子间的相互作用。
在这里插入图片描述

3.3.定义多层感知器

以下代码定义了包括超参数、创建关系掩码以及多层感知器(MLP)类。

# 检查第一个数据点的轨迹数据
x_sample[idx].shape
# torch.Size([5, 49, 4]) 表示每个条目包含5个粒子的轨迹,共有49个时间步,每个状态由4维向量指定。

# 打印第一个粒子在前两个时间步的位置和速度
x_sample[idx, 0, 0:2, :]
# tensor([[-0.1272, -0.0987, -0.3305,  0.1197],
#         [-0.1380, -0.0945, -0.3275,  0.1196]])

# 定义一些超参数
dims = 4  # 特征维度
num_atoms = 5  # 粒子数量
timesteps = 49  # 时间步数
lr = 0.0005  # 学习率
temp = 0.5  # 温度参数,用于Gumbel-Softmax采样
output_var = 5e-5  # 输出方差
_EPS = 1e-10  # 用于数值稳定性的小常数
# 编码器在完全连接的图上工作,以预测交互图
# 定义关系掩码以指定哪些顶点从其他顶点接收消息

def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
    return labels_onehot

# 创建一个全1矩阵,然后减去单位矩阵以得到非对角线元素
off_diag = np.ones([num_atoms, num_atoms]) - np.eye(num_atoms)

# 使用encode_onehot函数编码非对角线元素的接收和发送关系掩码
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
# 将关系掩码转换为PyTorch张量
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)
# 使用这些掩码,我们可以在完全连接的图上传递消息,将边特征转换为节点特征,将节点特征转换为边特征
print(rel_rec.t(), rel_rec.shape)
# 输出关系的转置矩阵和形状

# 例如,要将20个交互的边特征转换为节点特征,我们可以将上述矩阵与边特征向量相乘
# torch.matmul(rel_rec.t(), x)
# 这将为每个顶点收集所有邻居节点的消息并累加这些消息
# 接下来,我们定义一个简单的MLP类,用于非线性特征转换

class MLP(nn.Module):
    """两层全连接ELU网络,带批量归一化。"""

    def __init__(self, n_in, n_hid, n_out, do_prob=0.):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hid)  # 第一层全连接
        self.fc2 = nn.Linear(n_hid, n_out)  # 第二层全连接
        self.bn = nn.BatchNorm1d(n_out)  # 批量归一化
        self.dropout_prob = do_prob  # dropout概率

        self.init_weights()  # 初始化权重

    # 初始化权重
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal(m.weight.data)  # 使用Xavier初始化权重
                m.bias.data.fill_(0.1)  # 设置偏置为0.1
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)  # 设置批量归一化权重为1
                m.bias.data.zero_()  # 设置批量归一化偏置为0

    # 批量归一化函数
    def batch_norm(self, inputs):
        x = inputs.view(inputs.size(0) * inputs.size(1), -1)
        x = self.bn(x)
        return x.view(inputs.size(0), inputs.size(1), -1)

    # 前向传播函数
    def forward(self, inputs):
        # 输入形状: [num_sims, num_things, num_features]
        x = F.elu(self.fc1(inputs))  # 第一层激活函数
        x = F.dropout(x, self.dropout_prob, training=self.training)  # dropout
        x = F.elu(self.fc2(x))  # 第二层激活函数
        return self.batch_norm(x)  # 批量归一化输出

代码首先检查了数据集中某个粒子的轨迹数据的形状,然后定义了模型训练所需的一些超参数。接着,代码创建了用于在图神经网络中传递消息的关系掩码,这些掩码定义了哪些节点可以接收来自其他节点的消息。最后,代码定义了一个简单的多层感知器(MLP)类,这个类将被用于模型中的特征转换,包括两层全连接层和批量归一化层,以及相应的权重初始化和前向传播逻辑。

3.4.定义编码器

代码实现了一个图神经网络(GNN)架构的变分自编码器(VAE)编码器,用于分析每个粒子的轨迹数据。该编码器的目的是将每个粒子的完整轨迹信息综合成一个单一的特征向量。通过多层网络的转换,编码器最终生成用于图采样过程中边概率的参数。

具体来说,编码器首先通过多个MLP(多层感知器)层对节点特征进行处理,然后通过特定的消息传递机制将节点特征转化为边特征,再转化回节点特征,最终再次转化为边特征。这个过程中涉及的转换操作包括节点到边的消息传递和边到节点的消息传递,这些操作都是基于预先定义的关系掩码来实现的。通过这种方式,编码器能够捕捉粒子间的复杂相互作用。

编码器的输出是一个经过线性变换的特征向量,这个向量将作为图采样操作的输入,用于确定潜在图结构中边的存在概率。这个过程是通过一个输出层完成的,该层将MLP层的输出映射到所需的输出维度,从而为后续的图采样提供参数。

整个编码器的设计允许模型学习如何从观察到的轨迹中推断出潜在的交互结构,这是理解和预测粒子动态的关键。通过这种方式,编码器为VAE模型提供了一种强大的机制,用以探索和利用图结构数据中的复杂关系。

# 定义MLPEncoder类,它是一个图神经网络,用作VAE的编码器
class MLPEncoder(nn.Module):
    def __init__(self, n_in, n_hid, n_out=2, do_prob=0.):
        super(MLPEncoder, self).__init__()

        # 定义四个MLP层,用于特征转换
        self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob)
        self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob)
        self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob)
        self.mlp4 = MLP(n_hid * 3, n_hid, n_hid, do_prob)

        # 输出层,将隐藏层特征映射到输出维度
        self.fc_out = nn.Linear(n_hid, n_out)
        self.init_weights()  # 初始化权重

    # 初始化权重的函数
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal(m.weight.data)  # 使用Xavier初始化权重
                m.bias.data.fill_(0.1)  # 设置偏置为0.1

    # 将边特征转换为节点特征的函数
    def edge2node(self, x, rel_rec, rel_send):
        incoming = torch.matmul(rel_rec.t(), x)  # 通过关系矩阵接收消息
        return incoming / incoming.size(1)  # 归一化

    # 将节点特征转换为边特征的函数
    def node2edge(self, x, rel_rec, rel_send):
        receivers = torch.matmul(rel_rec, x)  # 通过关系矩阵发送消息
        senders = torch.matmul(rel_send, x)
        edges = torch.cat([senders, receivers], dim=2)  # 连接发送者和接收者特征
        return edges

    # 前向传播函数
    def forward(self, inputs, rel_rec, rel_send):
        x = inputs.view(inputs.size(0), inputs.size(1), -1)  # 调整输入形状

        x = self.mlp1(x)  # 第一个MLP层处理每个节点的特征

        x = self.node2edge(x, rel_rec, rel_send)  # 节点到边的转换
        x = self.mlp2(x)
        x_skip = x  # 保存当前特征用于跳跃连接

        x = self.edge2node(x, rel_rec, rel_send)  # 边到节点的转换
        x = self.mlp3(x)
        x = self.node2edge(x, rel_rec, rel_send)  # 再次进行节点到边的转换
        x = torch.cat((x, x_skip), dim=2)  # 跳跃连接
        x = self.mlp4(x)  # 第四个MLP层处理边特征

        return self.fc_out(x)  # 输出层

代码通过MLPEncoder类实现了一个图神经网络,该网络包含四个MLP层和两个消息传递步骤(节点到边和边到节点)。每个MLP层后面跟着一个非线性激活函数(ELU)和可选的dropout层。消息传递步骤使用定义好的关系掩码rel_recrel_send来实现。最终,编码器的输出是一个线性层,它将隐藏层的特征映射到输出维度,用于图采样操作的边概率参数化。

3.5.定义解码器

代码实现了一个解码器,该解码器是变分自编码器(VAE)的一个组成部分,它利用图神经网络(GNN)对节点和边的特征进行处理,并通过多层感知器(MLP)进行预测。解码器的目的是基于系统的初始状态来推断其未来的轨迹。以下是对这段描述的改写:

所提供的代码构建了VAE的解码器模块,该模块旨在从图结构数据中预测物体的未来运动轨迹。解码器首先接收物体的初始状态,然后通过一系列图神经网络层来分析节点间的交互和特征。这些特征随后被送入一个多层感知器,该感知器负责计算并输出对物体未来状态的预测。解码器的设计允许它在多个时间步骤上进行操作,逐步构建出完整的轨迹预测。

在这个解码器中,消息传递机制是关键,它允许节点基于其邻居的信息来更新自己的特征表示。解码器通过这种方式能够捕捉到物体间的复杂相互作用,并将这些信息融合到对未来状态的预测中。每一层的输出都会被送入后续的MLP中,以便进一步提取和转换特征,最终生成对下一时间步的预测。

通过这种方式,解码器不仅能够考虑到单个物体的运动,还能够考虑到物体间的相互影响,从而提供更为准确和全面的轨迹预测。这种基于图的预测方法在处理多体系统,如群体运动、分子动力学等领域中具有重要的应用价值。

# 定义MLPDecoder类,用作VAE的解码器模块
class MLPDecoder(nn.Module):
    """MLP解码器模块."""

    def __init__(self, n_in_node, edge_types, msg_hid, msg_out, n_hid, do_prob=0.):
        super(MLPDecoder, self).__init__()

        # 定义消息传递层和输出层的全连接网络
        self.msg_fc1 = nn.Linear(2 * n_in_node, msg_hid)
        self.msg_fc2 = nn.Linear(msg_hid, msg_out)
        self.msg_out_shape = msg_out

        self.out_fc1 = nn.Linear(n_in_node + msg_out, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, n_in_node)

        print('使用学习到的交互网络解码器.')
        self.dropout_prob = do_prob

    # 单步预测函数
    def single_step_forward(self, single_timestep_inputs, rel_rec, rel_send, single_timestep_rel_type):
        # 节点到边的消息传递
        receivers = torch.matmul(rel_rec, single_timestep_inputs)
        senders = torch.matmul(rel_send, single_timestep_inputs)
        pre_msg = torch.cat([senders, receivers], dim=-1)

        # 通过激活函数和dropout处理消息
        msg = F.relu(self.msg_fc1(pre_msg))
        msg = F.dropout(msg, p=self.dropout_prob)
        msg = F.relu(self.msg_fc2(msg))
        msg = msg * single_timestep_rel_type[:, :, :, 1:2]

        # 将消息聚合到接收节点
        agg_msgs = msg.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1)
        agg_msgs = agg_msgs.contiguous()

        # 跳跃连接
        aug_inputs = torch.cat([single_timestep_inputs, agg_msgs], dim=-1)

        # 输出MLP
        pred = F.dropout(F.relu(self.out_fc1(aug_inputs)), p=self.dropout_prob)
        pred = F.dropout(F.relu(self.out_fc2(pred)), p=self.dropout_prob)
        pred = self.out_fc3(pred)

        # 预测位置/速度的差异
        return single_timestep_inputs + pred

    # 前向传播函数
    def forward(self, inputs, rel_type, rel_rec, rel_send, pred_steps=1):
        # 假定所有样本使用相同的图

        inputs = inputs.transpose(1, 2).contiguous()

        # 展开关系类型张量以匹配输入尺寸
        sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1), rel_type.size(2)]
        rel_type = rel_type.unsqueeze(1).expand(sizes)

        # 初始化预测列表
        time_steps = inputs.size(1)
        assert (pred_steps <= time_steps)
        preds = []

        # 初始步骤
        last_pred = inputs[:, 0:1, :, :]
        curr_rel_type = rel_type[:, 0:1, :, :]  # 假定关系类型在所有时间步是常量

        # 运行n步预测
        for step in range(0, pred_steps):
            last_pred = self.single_step_forward(last_pred, rel_rec, rel_send, curr_rel_type)
            preds.append(last_pred)

        # 重新组装正确的时间线
        output = torch.zeros(sizes)
        if inputs.is_cuda:
            output = output.cuda()
        for i in range(len(preds)):
            output[:, i:i+1, :, :] = preds[i]

        # 返回预测的轨迹,排除最后一步
        pred_all = output[:, :(inputs.size(1) - 1), :, :]

        return pred_all.transpose(1, 2).contiguous()

接下来的代码块创建了辅助数组,用于在消息传递步骤中指定发送和接收节点,这些节点在完全连接的图中对应于邻居节点:

# 生成非对角线交互图
off_diag = np.ones([num_atoms, num_atoms]) - np.eye(num_atoms)

# 根据非对角线元素生成关系接收和发送掩码
rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
rel_rec = torch.FloatTensor(rel_rec)
rel_send = torch.FloatTensor(rel_send)

# 编码器和解码器的初始化
encoder = MLPEncoder(timesteps * dims, 256, 2)
decoder = MLPDecoder(n_in_node=dims,
                     edge_types=2,
                     msg_hid=256,
                     msg_out=256,
                     n_hid=256, )
print('使用学习到的交互网络解码器.')

# 警告信息,表示nn.init.xavier_normal已被弃用,建议使用nn.init.xavier_normal_
# 将模型和关系掩码转换为CUDA张量
encoder.cuda()
decoder.cuda()
rel_rec = rel_rec.cuda()
rel_send = rel_send.cuda()

# 定义优化器
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                       lr=lr)

# 加载数据
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data(
    128, "_springs5")

代码通过MLPDecoder类实现了解码器,它接受初始状态作为输入,并尝试预测剩余的轨迹。解码器首先通过single_step_forward函数来进行单步预测,该函数将当前状态和提议的交互图作为输入,通过节点状态生成边特征,然后通过MLP计算下一个预测作为当前状态的差异。这个过程对所有时间步重复进行。在forward函数中,解码器将输入数据转换为正确的形状,并通过一系列预测步骤生成最终的轨迹预测。最后,代码初始化了编码器和解码器,并将它们转换为CUDA张量以在GPU上运行,然后定义了优化器并加载了数据。

3.6.训练模型

代码详细阐述了通过变分自编码器(VAE)进行模型训练的过程,涵盖了关键步骤:利用Gumbel-Softmax技术进行离散图采样、定义损失函数以评估预测准确性,以及通过反向传播更新模型权重。

首先,代码实现了Gumbel-Softmax采样,这是一种从连续分布中采样出离散样本的技术,适用于VAE中潜在变量的采样。通过定义my_softmax函数来执行一维softmax操作,sample_gumbel函数用于生成Gumbel噪声,而gumbel_softmax_samplegumbel_softmax函数则用于将这些噪声整合到潜在变量的采样中。

接着,代码中定义了几个辅助函数,包括kl_categorical_uniform用于计算分类问题的KL散度,nll_gaussian用于计算高斯分布下的负对数似然,以及edge_accuracy用于计算预测的准确率。

之后,train函数详细描述了模型训练的流程。在每个训练周期中,模型首先进入训练模式,然后对训练数据集进行遍历,执行编码器和解码器的前向传播,计算重构损失和KL散度,随后执行反向传播并更新模型参数。训练过程中还包括了对模型在验证集上性能的评估。

最后,代码中的训练循环设置了训练周期数,并在每个周期结束后评估模型在验证集上的性能,记录最低验证损失所对应的周期,从而追踪模型的最佳表现。

整体而言,这段代码为使用变分自编码器进行图结构数据建模提供了一个完整的训练框架,包括图的离散采样、损失函数的定义、模型参数的更新,以及训练过程中的性能监控。

# 离散采样
# 我们将通过使用编码器输出的权重选择完全连接图的子集来采样图,使用Gumbel softmax方法。
# 下面定义了Gumbel softmax的例程。

def my_softmax(input, axis=1):
    # 一维softmax操作的自定义实现
    trans_input = input.transpose(axis, 0).contiguous()
    soft_max_1d = F.softmax(trans_input)
    return soft_max_1d.transpose(axis, 0)

def sample_gumbel(shape, eps=1e-10):
    # 采样Gumbel分布的噪声
    U = torch.rand(shape).float()
    return - torch.log(eps - torch.log(U + eps))

def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
    # 根据Gumbel分布采样并进行softmax操作
    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
    if logits.is_cuda:
        gumbel_noise = gumbel_noise.cuda()
    y = logits + Variable(gumbel_noise)
    return my_softmax(y / tau, axis=-1)

def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
    # 根据Gumbel-Softmax技术进行采样,可以选择软硬采样
    y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps)
    if hard:
        # 硬采样,即one-hot采样
        shape = logits.size()
        _, k = y_soft.data.max(-1)
        y_hard = torch.zeros(*shape)
        if y_soft.is_cuda:
            y_hard = y_hard.cuda()
        y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
        y = Variable(y_hard - y_soft.data) + y_soft
    else:
        y = y_soft
    return y

# 下面定义了一些辅助函数,用于计算准确率、预测损失和KL散度。

def kl_categorical_uniform(preds, num_atoms, num_edge_types, add_const=False, eps=1e-16):
    # 计算KL散度
    kl_div = preds * torch.log(preds + eps)
    if add_const:
        const = np.log(num_edge_types)
        kl_div += const
    return kl_div.sum() / (num_atoms * preds.size(0))

def nll_gaussian(preds, target, variance, add_const=False):
    # 计算高斯分布下的负对数似然
    neg_log_p = ((preds - target) ** 2 / (2 * variance))
    if add_const:
        const = 0.5 * np.log(2 * np.pi * variance)
        neg_log_p += const
    return neg_log_p.sum() / (target.size(0) * target.size(1))

def edge_accuracy(preds, target):
    # 计算边缘准确率
    _, preds = preds.max(-1)
    correct = preds.float().data.eq(target.float().data.view_as(preds)).cpu().sum()
    return np.float(correct) / (target.size(0) * target.size(1))

# 现在我们可以开始训练模型了。这包括运行编码器来获取潜在图参数,使用gumbel_softmax函数采样边,得到潜在图,然后将其传递给解码器。
# 我们使用均匀分类先验来计算VAE损失的KL散度,并使用固定方差的高斯似然损失来计算预测。

def train(epoch, best_val_loss):
    # 训练函数定义
    # ...
    # 训练过程包括对训练数据加载器的遍历,编码器和解码器的前向传播,损失的计算和反向传播,以及参数的更新。

    # 打印训练和验证过程中的损失和准确率等信息
    print('Epoch: {:04d}'.format(epoch),
          # ...
          'acc_val: {:.10f}'.format(np.mean(acc_val)),
          'time: {:.4f}s'.format(time.time() - t))
    return np.mean(nll_val)

# 训练循环
t_total = time.time()
best_val_loss = np.inf
best_epoch = 0
for epoch in range(10):
    val_loss = train(epoch, best_val_loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch

3.7.可视化预测结果

以下代码展示了如何可视化一些示例中的真实和预测的交互图。

# 可视化发现的图
# 现在我们可以为一些示例可视化实际和预测的交互图。

# 从验证数据加载器中获取数据和关系
data, relations = next(iter(valid_loader))
# 将数据和关系转换为CUDA张量以便在GPU上进行计算
data = data.cuda()
relations = relations.cuda()

# 使用编码器处理数据,获取潜在变量的对数几率(logits)
logits = encoder(data, rel_rec, rel_send)
# 通过最大化最后一个维度来获取预测的关系
_, rel = logits.max(-1)

# 打印第一个样本的预测和实际关系
print(rel[0])  # 预测的关系
print(relations[0])  # 实际的关系

# 循环遍历每个样本,并将其实际和预测的交互图转换为邻接矩阵形式
for i in range(5):
    # 将实际和预测的关系转换为邻接矩阵
    g_act = list_to_adj(relations[i])  # 实际的交互图邻接矩阵
    g_pred = list_to_adj(rel[i])  # 预测的交互图邻接矩阵

    print("Original")  # 打印实际的交互图
    show_graph(g_act)  # 显示实际的交互图
    print("Predicted")  # 打印预测的交互图
    show_graph(g_pred)  # 显示预测的交互图

代码中,valid_loader是一个迭代器,用于从验证数据集中获取数据。datarelations分别代表验证数据和对应的关系。使用编码器encoder对数据进行处理,得到潜在变量的对数几率(logits),然后通过最大化操作获取预测的关系rel

接着,代码打印出第一个样本的预测关系和实际关系,并使用一个循环来为每个样本生成实际和预测的交互图的邻接矩阵。list_to_adj函数用于将关系列表转换为邻接矩阵。然后,使用show_graph函数来可视化这些交互图。

在这里插入图片描述

原始图像

在这里插入图片描述
预测图像

4. 总结和展望

4.1 总结

本文介绍了神经关系推理(Neural Relational Inference, NRI)模型,这是一种基于变分自编码器(VAE)的图神经网络(GNN),用于从观测数据中同时学习系统的动态和组分间的相互作用。NRI模型在多个方面展现了其强大的能力:

  • 无监督学习:NRI能够在没有显式交互信息的情况下,通过观测到的轨迹数据来学习系统的动态和潜在的交互图。
  • 图结构建模:通过GNN对图结构数据进行操作,模型能够捕捉个体间的复杂关系,并用于动态预测。
  • 多步预测:不同于传统的VAE,NRI训练解码器进行多步预测,更好地模拟了动态系统的时间序列特性。
  • 连续松弛:针对离散潜在变量,NRI采用了连续松弛技术,使得模型参数可以通过反向传播进行优化。

在模拟物理系统、真实运动捕捉和NBA篮球追踪数据等多个实验中,NRI模型都展现出了优越的性能,能够高精度地恢复交互图并预测未来的动态变化。

4.2 展望

尽管NRI模型在多个领域取得了显著的成果,但仍有若干潜在的改进方向和未来的研究方向:

  • 动态图结构学习:当前的NRI模型在训练时假设图结构是静态的。未来的工作可以探索如何让模型在训练阶段就能捕捉到图结构的动态变化。
  • 更复杂的系统:NRI模型可以进一步应用于更复杂的系统,如交通流、社交网络动态等,这些系统具有高度的动态性和非线性特征。
  • 模型泛化能力:研究NRI模型在不同类型数据和任务上的泛化能力,以及如何通过正则化技术提高模型的泛化性。
  • 计算效率:随着系统规模的增大,如何提高NRI模型的计算效率和可扩展性也是一个重要的研究方向。
  • 与其他模型的结合:探索NRI与其他模型(如注意力机制、序列模型等)的结合,以利用各自的优势解决更复杂的问题。

NRI模型作为一种新兴的图神经网络方法,在理解和预测相互作用系统方面具有巨大的潜力。随着进一步的研究和发展,它有望在更多领域发挥重要作用。

参考文献

  1. Kipf, T. N., & Welling, M. (2017). Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations (ICLR).
  2. Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., & Dahl, G. E. (2017). Neural message passing for quantum chemistry. In International Conference on Machine Learning (ICML).
  3. Battaglia, P. W., Pascanu, R., Lai, M., Rezende, D. J., & Kavukcuoglu, K. (2016). Interaction networks for learning about objects, relations and physics. In Advances in Neural Information Processing Systems (NIPS).
  4. Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR).
  5. Maddison, C. J., Mnih, A., & Teh, Y. W. (2017). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. In International Conference on Learning Representations (ICLR).
  6. Sukhbaatar, S., Szlam, A., & Fergus, R. (2016). Learning multiagent communication with backpropagation. In Advances in Neural Information Processing Systems (NIPS).
  7. Watters, N., Zoran, D., Weber, T., Battaglia, P., Pascanu, R., & Tacchetti, A. (2017). Visual interaction networks: Learning a physics simulator from video. In Advances in Neural Information Processing Systems (NIPS).
评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值