人工智能咨询培训老师叶梓 转载标明出处
时空学习模型通过融合多种空间卷积和时间演化技术,有效捕捉城市数据的复杂异质性。然而,快速的城市化进程导致城市数据分布和结构频繁变动,这与现有模型假设训练和测试数据独立同分布的情况相悖。现实世界中,城市时空要素的扩张和增长引发分布偏移,使得模型在遭遇新的数据实例时,其泛化能力受到限制,难以灵活适应数据的快速演变。为了克服这一难题,中国科学技术大学的研究团队提出了一种创新的互补时空学习系统ComS2T,该系56统通过促进模型自适应演化,有效应对数据分布变化带来的挑战。
方法
ComS2T模型的提出基于神经科学中的互补学习理论,该理论指出大脑中不同区域在记忆历史知识和吸收新知识方面发挥着不同的角色。ComS2T模型通过有效地将学习权重解耦为稳定的新皮层和动态的海马体两个互补子空间,实现了对时空数据的动态适应。其中,新皮层负责巩固历史记忆,而海马体则负责更新新知识。
ComS2T模型的核心在于两个主要方面。首先,它通过显式建模学习行为,高效地将可学习神经权重解耦为稳定的新皮层和动态的海马体两个互补子空间。这两个结构协同工作,动态适应流式的时空数据。其次,ComS2T通过自监督学习预训练时空提示,弥补了可学习提示和特定数据模式之间的差距。这种预训练策略允许在测试时进行训练,使模型对数据分布的变化更加敏感。
为了实现这一目标,ComS2T设计了一个渐进式学习架构(Figure 2),包括四个主要组成部分:高效的神经元解耦、提示预训练、基于提示的微调和测试时的自适应。这一架构实现了从粗粒度到细粒度的数据适应。
在神经解耦方面,ComS2T提出了一种高效的解耦方法,将集成神经网络中的潜在新皮层和海马体结构解耦。通过分别将空间模块和时间模块作为独立的单元,并假设有K层空间聚合和L次时间卷积,研究团队将空间邻接性和特征级缩放作为空间可学习空间WS,将时间可学习权重集合指定为WT。通过差分累积策略,捕捉训练过程中的学习行为演化,从而实现对模型行为的表征。
另外ComS2T通过自监督学习进行时空提示的预训练,这些提示作为中间变量,将数据的变化传递给主模型。通过选择信息丰富的空间和时间信号作为提示的基本元素,并显式地将分布建模为对连续观测的数据摘要,构建了时空提示和数据摘要之间的问答对,使提示对数据分布敏感。
最后,ComS2T通过两阶段训练(热身和微调)将提示学习与渐进式学习能力结合起来。在模型训练期间,通过基于提示的微调,利用新皮层保持稳定权重以传递跨环境的不变关系,并利用时空提示指导海马体随着分布变化进行更新。这种设计创新地固定了时空观测中的不变关系,并通过上下文回归残差观测,使得微调只需要有限的计算。
Algorithm 1 描述了ComS2T的训练过程,包括神经解耦、自监督提示训练、基于提示的海马体结构微调。通过这些步骤,模型可以在训练和测试阶段适应数据分布的变化。
Algorithm 2提供了ComS2T测试过程的详细步骤,包括从测试集中采样部分观测、更新时空提示以及基于更新后的提示进行预测。
想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具(限时免费)。
1小时实战课程,您将学习到如何轻松上手并有效利用 Llama Factory 来微调您的模型,以发挥其最大潜力。
CSDN教学平台录播地址:https://edu.csdn.net/course/detail/39987
实验
研究者们收集了四类时空数据,包括交通、空气质量和智能电网,以此来验证所提出的数据自适应学习架构ComS2T。这些数据集的统计信息见Table II。
数据集涵盖了苏州工业园区(SIP)的交通流量监控、美国洛杉矶的交通属性数据(Metr-LA)、覆盖中国184个主要城市的PM2.5浓度(KnowAir)以及相同城市的城区气温数据(Temperature)。
为了模拟时空分布的偏移,研究者们设计了不同的学习-测试场景。这包括在时间维度上构建数据分布偏移,同时在空间维度上模仿结构偏移,如Figure 3所示。例如,对于SIP和Metr-LA这两个高度动态的交通数据集,通过收集相同时间段(例如每天的8:00-16:00)的数据作为模型学习的训练集,而在另一个未见的时间段(例如每天的1:00-7:00)进行测试,以此来模仿时间分布偏移。对于空气质量和气候数据集,由于它们在短期内相对静态但季节性变化,研究者们将全年记录分为四个季度,用其中两个季度的数据进行训练,而在另一个季度进行测试。
通过引入新节点和移除现有节点来实现空间分布偏移。例如,在测试阶段,通过主动屏蔽一系列现有节点,然后在测试阶段将它们重新添加回来,以模拟图结构的新连接。同样,在测试阶段移除一些现有节点,以模仿动态图结构中的节点消失。
Table III展示了ComS2T的配置,包括图波网络(GraphWaveNet, GWN)作为骨干网络,学习率为1e-4,时空提示的维度为(64,16,16,32),稳定新皮层的比例τ为(60%, 60%, 60%, 70%),GNN的隐藏维度为32,TCN核的维度为(12,6,3),批量大小为64,优化器为Adam。
在实现细节方面,对于每个数据集,都按照Sec. IV-B中的设置组织样本组。对于SIP和Metr-LA,取1/3的样本用于训练,即每天的8:00-16:00,16:00-24:00的样本用于验证,0:00-7:00的样本用于测试,其中0:00-1:00的样本用于测试时的数据自适应。对于空气质量和气候数据集,取每年前六个月(1月至6月)的样本用于训练,7月和8月的样本用于验证,10月至12月的样本用于测试,其中9月的样本用于测试时的数据自适应。
在测试阶段,研究者们利用时间上最接近的样本进行测试时模型适应,这允许基于分布偏移更新提示。对于结构偏移,当引入新节点时,利用分布监督学习方案,基于一些新观察更新空间和时间提示。对于邻接矩阵,采用节点复制策略,找到与新节点最接近的节点,并将现有类似节点的邻接性复制到新节点,从而为测试构建扩展的关系空间邻接性。当从时空图中移除现有节点时,重新训练空间和时间提示,并在邻接矩阵中屏蔽移除节点的相应行和列,以进行维度对齐。
在性能评估方面,每个基线模型和ComS2T都实现了五次,报告平均误差。采用平均绝对误差(MAE)作为主要评估指标。误差可以表示为:
其中,是节点i在时间步t的预测观测值,而是相应的真实值。
在与竞争对手的性能比较中,ComS2T在大多数场景下都优于基线模型,无论是在时间分布偏移下的性能提高了0.73%到20.70%,还是在结构偏移下的性能提高了1.19%到17.30%。例如,在Metr-LA数据集上,ComS2T显示出显著的改进,这可能归因于时空提示和主要观测之间的良好规律性。
另外通过消融研究(Ablation study),移除了ComS2T中的特定模块,以验证每个精心设计的组件或学习策略的贡献。例如,不识别动态海马体结构,用整个架构进行OOD推断,或者在没有明确识别海马体和新皮层结构的情况下,用提示训练和更新整个神经架构。这些实验表明,ComS2T在所有变体中都取得了最佳性能,证实了海马体结构的有效性。
最后,通过超参数分析(Hyperparameter analysis),研究者们选择了两个关键的超参数来观察模型随着参数变化的行为。例如,稳定新皮层的比例τ的范围是{50%, 60%, 70%, 80%},时空提示的维度E在{16, 32, 64, 128}之间变化。通过这些实验,研究者们不仅获得了满意的结果,而且还为城市研究提供了洞见。
详细的案例研究和模型探索进一步回答了以下两个研究问题(RQ):提示如何解释动态的空间和时间背景,以及提示是否可以适应主要观测值的分布变化;以及ComS2T架构的可学习参数随着学习过程的行为,解耦和部分更新神经结构是否有效提高了性能和泛化能力。通过在不同背景下对Metr-LA的节点2和11的时空提示进行可视化,研究者们展示了学习到的提示之间的绝对差异,其中相似的天区间显示出一些相似性。这些具有合理区分度和相似性的可视化中间结果表明,自监督学习信号的分布可以有效指导提示的优化过程,并针对不同分布获得不同的提示,允许数据适应被传递到海马体结构的更新。
通过这些详细的实验设置和分析,ComS2T模型不仅在理论上推动了时空学习的研究,而且在实际应用中,如城市交通流量预测、空气质量监测等领域,具有重要的应用价值。通过这种创新的学习框架,可以更有效地处理和预测城市中的复杂时空数据,为智能城市的可持续发展提供强有力的技术支撑。
论文链接:https://arxiv.org/pdf/2403.01738.pdf