2025 | Nature(TabPFN)准确预测小样本数据的表格基础模型

Nature:准确预测小样本数据的表格基础模型

标题Accurate predictions on small data with a tabular foundation model
作者Noah Hollmann, Samuel Müller, Lennart Purucker, Arjun Krishnakumar, Max Körfer, Shi Bin HooRobin Tibor Schirrmeister & Frank Hutter
期刊Nature|2025
机构University of Freiburg
论文Accurate predictions on small data with a tabular foundation model
代码https://github.com/PriorLabs/TabPFN

摘要

表格数据(如行和列组成的电子表格)在科学研究中无处不在,从生物医学到粒子物理、经济和气候科学等领域都有广泛应用。基于表格数据的预测任务(如填补标签列的缺失值)对多种应用场景至关重要,包括生物医学风险模型、药物发现和材料科学等。目前,尽管深度学习在原始数据学习方面取得了革命性进展,但在表格数据领域,梯度提升决策树(如 XGBoost 和 LightGBM)在过去 20 年中一直是主导方法。深度学习在表格数据上的表现通常不如传统方法,尤其是在小数据集上。论文提出了一种新的表格基础模型 TabPFN(Tabular Prior-data Fitted Network)。该模型基于生成式Transformer架构,通过在数百万个合成数据集上进行训练,实现了对小到中型表格数据(样本量≤10,000,特征数≤500)的高效预测。实验表明,TabPFN在分类和回归任务中显著优于现有方法,例如在分类任务中以2.8秒的推理时间超越经过4小时调优的CatBoost集成模型(速度提升5,140倍),同时具备生成数据、密度估计和可解释性等基础模型特性。这一研究为表格数据分析提供了一种端到端学习的范式,有望加速科学发现与决策优化。

1 引言

在人工智能的发展历程中,人工设计的算法组件逐渐被性能更优的端到端学习组件所取代。这一趋势在计算机视觉、自然语言处理和游戏领域表现得尤为明显。然而,表格数据由于其多样性和复杂性,使得深度学习方法在处理这类数据时面临诸多挑战。表格数据的多样性体现在不同数据集中的相同值可能具有完全不同的含义,这导致了大量小型、独立数据集及其相关模型的出现。此外,表格数据本身具有异质性,包含不同尺度和类型(如布尔型、分类型、序数型、整数型、浮点型)的列,以及不平衡或缺失的数据、不重要的特征和异常值等。这些特性使得非深度学习方法(如基于树的模型)在表格数据处理中表现更为出色。不过,传统方法(如梯度提升决策树)虽在单数据集上表现优异,但难以实现跨数据集的知识迁移,且无法建模预测不确定性。

深度学习的局限与机遇:尽管深度学习在图像和自然语言处理领域取得突破,但在表格数据上的应用仍受限。原因包括:(1)表格数据缺乏空间或序列的局部相关性;(2)小样本场景下模型易过拟合;(3)传统树模型在计算效率上具有优势。然而,基于Transformer的大模型通过上下文学习(In-Context Learning, ICL)展现了强大的算法学习能力,为表格数据分析提供了新思路。

本文提出TabPFN,首次将基础模型范式引入表格数据分析,核心创新包括:

  1. 基于合成数据的预训练框架:通过结构因果模型生成百万级合成数据集,覆盖真实数据的多样性挑战(如缺失值、异常值、无关特征);

  2. 双向注意力架构:设计针对表格结构的Transformer层,实现样本与特征的双向交互,支持快速推理与状态缓存;

  3. 多功能基础模型特性:除预测外,支持数据生成、密度估计、特征嵌入与微调,为下游任务提供统一工具。

2 方法

2.1 有原则的上下文学习

TabPFN 利用了上下文学习(in-context learning,ICL),这种机制也是大型语言模型取得惊人性能的原因,从而生成了一种全新的、完全基于学习的表格预测算法。尽管 ICL 最初是在大型语言模型中观察到的,但近期的研究表明,Transformers可以通过 ICL 学习简单的算法,例如逻辑回归。此外,先验数据拟合网络(PFNs)也表明,即使是复杂的算法,如高斯过程和贝叶斯神经网络,也可以通过 ICL 进行近似。

TabPFN 的核心思想是生成大量合成表格数据集,然后训练一个基于Transformer的神经网络来学习解决这些合成预测任务。与传统方法不同,传统方法需要为数据挑战(如缺失值)设计手工解决方案,而我们的方法通过解决包含这些挑战的合成任务,自主学习有效的策略。这种方法利用 ICL 作为一种基于示例的算法编程框架。我们通过生成多样化的合成数据集来设计期望的算法行为,然后训练一个模型来编码满足这种行为的算法。这将算法设计过程从编写显式指令转变为定义输入-输出示例,为各个领域的算法创建开辟了新的可能性。在这里,我们将这种方法应用于影响深远的表格学习领域,生成了一种强大的表格预测算法。

我们的 ICL 方法与标准的监督深度学习有根本的不同。通常情况下,模型是针对单个数据集进行训练的,根据手工设计的权重更新算法(如 Adam)对单个样本或批次更新模型参数。在推理时,学习到的模型被应用于测试样本。相比之下,我们的方法是在多个数据集上进行训练的,并且在推理时应用于整个数据集,而不是单个样本。在应用于真实世界的数据集之前,模型会预先在数百万个代表不同预测任务的合成数据集上进行一次性的预训练。在推理时,模型会接收一个未见过的数据集,该数据集同时包含标记过的训练样本和未标记过的测试样本,并通过单次神经网络前向传递在这个数据集上执行训练和预测。

image-20250403154839065

图1展示了所提出方法的概述。a部分展示了TabPFN预训练和使用的高级概述。b部分展示了TabPFN的架构。我们训练了一个模型来解决超过一亿个合成任务。我们的架构是标准Transformer编码器的改进版,适用于表格中遇到的二维数据。

图 1 和图 2 概述了我们的方法:

  1. 数据生成:我们定义了一个生成过程(称为我们的先验)来合成多样化的表格数据集,这些数据集在特征和目标之间具有不同的关系,旨在捕捉模型可能遇到的各种潜在场景。我们从生成过程中采样了数百万个数据集。对于每个数据集,我们会掩盖一部分样本的目标值,模拟一个监督预测问题。

  2. 预训练:我们训练了一个Transformer模型,即我们的 PFN,用来预测所有合成数据集的被掩盖目标值,给定输入特征和未被掩盖的样本作为上下文。这一步在模型开发过程中只进行一次,学习一个通用的学习算法,可以用于预测任何数据集。

  3. 真实世界预测:现在,经过训练的模型可以应用于任意未见过的真实世界数据集。训练样本被提供给模型作为上下文,模型通过 ICL 预测这些未见数据集的标签。

我们的方法也有理论基础,如参考文献 22 所述。它可以被视为对由合成数据集定义的先验的贝叶斯预测的近似。训练好的 PFN 将近似后验预测分布 p ( y t e s t ∣ X t e s t , X t r a i n , y t r a i n ) p(y_{test}|X_{test},X_{train},y_{train}) p(ytestXtest,Xtrain,ytrain),从而对 PFN 预训练期间使用的人工数据集的指定分布返回贝叶斯预测。

image-20250403155936425

图2展示了TabPFN先验的概述。(a)部分,对于每个数据集,我们首先对高级超参数进行采样。(b)部分,基于这些超参数,我们构建一个结构因果模型,该模型编码生成数据集的计算函数。计算图中的每个节点保存一个向量,每条边根据连接类型之一实现一个函数。在步骤1中,我们使用随机噪声变量生成初始化数据,这些数据被输入到图的根节点,并为每个待生成的样本在计算图中传播。在步骤2中,我们在图中随机采样特征和目标节点位置,分别标记为F和T。在步骤3中,我们提取在采样特征和目标节点位置的中间数据表示。在步骤4中,我们对提取的数据进行后处理。©部分,我们检索最终的数据集,并绘制特征对之间的交互,节点颜色表示样本的类别。

2.2 为表设计的架构

Transformer架构因其灵活性而成为深度学习和基础模型的首选架构。Transformer模型通过所谓的注意力机制在序列项之间结合信息,从而有效捕捉长距离依赖关系并学习数据中的复杂关系。尽管基于Transformer的模型可以应用于表格数据,但TabPFN解决了它们固有的两个关键限制。首先,由于Transformer是为序列设计的,它们将输入数据视为一个单一序列,而没有利用表格结构。其次,机器学习模型通常采用拟合-预测模式,即模型在训练集上拟合一次,然后用于多个测试数据集。然而,基于Transformer的ICL算法在单次传递中接收训练和测试数据,因此同时进行训练和预测。因此,当重用已拟合的模型时,必须对训练集重新进行计算。

为了更好地利用表格结构,我们提出了一种为表格中的每个单元格分配独立表示的架构,这一构想受到了文献2228的启发。我们的架构(如图1b所示)采用双向注意力机制,每个单元格首先关注其所在行(即其样本)中的其他特征,然后关注其所在列(即所有其他样本)中的相同特征。这种设计使架构对样本和特征的顺序具有不变性,并且能够更有效地进行训练,并且可以推广到比训练期间遇到的更大的表格,无论是在样本数量还是特征数量上。

image-20250403154839065

为了减少在拟合-预测设置中每个测试样本对训练集的重复计算,我们的模型可以将训练样本和测试样本的推理过程分开。这使我们能够对训练集执行一次ICL,保存结果状态,并将其用于多次测试集推理。在具有10,000个训练样本和10个特征的数据集上,我们经过优化的训练状态缓存技术在CPU上实现了约300倍的推理加速(从32秒缩短到0.1秒),在GPU上实现了6倍的加速。当特征数量增加10倍(达到100个)时,CPU上的加速倍数增加到800倍,GPU上的加速倍数增加到30倍。这些测量仅关注核心推理过程,不包括在“推理细节”部分中详细描述的预处理和集成步骤。GPU上的加速倍数较低是由于其大规模并行架构的利用率不足。

我们通过半精度计算层范数、使用Flash Attention、激活检查点和状态的顺序计算,进一步优化了架构的内存和计算需求。这些优化将内存需求减少了四倍,每个单元格的内存需求不到1,000字节。这使得在单个H100 GPU上能够对多达5,000万单元格的数据集(例如,500万行×10个特征)进行预测。

Flash Attention:是一种优化注意力机制的高效计算方法,通过分块矩阵乘法和梯度检查点技术减少内存占用,同时利用 GPU 的并行计算能力加速注意力计算,从而显著提高深度学习模型的内存效率和计算速度。具体而言,FlashAttention使用平铺和重计算等经典技术,将输入块从HBM(高带宽内存)加载到 SRAM (快速缓存),在SRAM上执行注意力操作,并将结果更新回HBM。 FlashAttention减少了内存读写量,从而实现了 2-4倍 的时钟时间加速。

[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

对于回归任务,我们采用分段常数输出分布,遵循文献2230的做法,这使我们的模型能够预测目标值的概率分布,而不是单一值,例如双峰分布。

2.3 基于因果模型的合成数据

TabPFN 的性能依赖于生成能够捕捉现实世界表格数据特征和挑战的合成训练数据集。为此,我们开发了一种基于结构因果模型(SCMs)的方法。SCMs 提供了一个正式框架,用于表示数据背后的因果关系和生成过程。通过使用合成数据而不是大量公共表格数据集合,我们避免了基础模型常见的问题,例如隐私和版权侵犯、训练数据被测试数据污染或数据可用性有限等问题。

如图 2 所示,我们的生成流程首先对高级超参数进行采样,例如数据集大小、特征数量和难度级别,以控制每个合成数据集的整体特性。在这些超参数的指导下,我们构建一个有向无环图,以指定数据集背后的因果结构。

image-20250403155936425

为了生成数据集中的每个样本,我们通过因果图的根节点传播随机生成的噪声,称为初始化数据。这些初始化数据通过从随机正态或均匀分布中采样生成,并在样本之间具有不同程度的非独立性。当这些数据在计算图的边上传播时,我们应用多种计算映射:小型神经网络(具有线性或非线性激活函数,例如 sigmoid、ReLU、取模、正弦函数)、用于生成分类特征的离散化机制以及用于编码局部规则依赖关系的决策树结构。在每条边上,我们添加高斯噪声,以引入生成数据的不确定性。我们在每个节点保存中间数据表示,以便后续检索。

在遍历因果图之后,我们在采样的特征和目标节点处提取中间表示,从而生成一个由特征值和关联目标值组成的样本。

通过在合成数据集中引入各种数据挑战和复杂性,我们为 TabPFN 创建了一个训练场所,使其能够开发出处理现实世界数据集中类似问题的策略。例如,考虑表格数据中常见的缺失值问题。通过在合成数据生成过程中让 TabPFN 暴露于具有不同模式和比例的缺失值的合成数据集中,模型学会了处理缺失值的有效方法,并能够推广到现实世界的数据集。我们应用后处理技术以进一步增强所学预测算法的真实性和鲁棒性挑战。这包括使用 Kumaraswamy 分布进行变形、引入复杂的非线性失真以及模仿离散化特征的量化。

通过这个生成过程,我们为每次模型训练创建了一个包含约 1 亿个合成数据集的大型语料库,每个数据集都具有独特的因果结构、特征类型和功能特性。

3 实验结果以及分析

评估的所有数据集均可在openml.orgkaggle.com上公开获得。我们在代码库中提供了脚本,可自动执行下载和评估数据集的过程。这些脚本包含数据集标识符,以及精确的数据拆分和处理过程。

3.1 TabPFN(PHE)

我们比较了 TabPFN 与 AutoGluon 1.0(文献 40)的性能。AutoGluon 1.0 将各种机器学习模型(包括我们的基线模型)组合成一个堆叠集成(stacked ensemble),调整它们的超参数,然后使用事后集成(Post Hoc Ensembling,PHE)生成最终预测。因此,它代表了一类与单一基线模型不同的方法。

为了进一步提升 TabPFN 的推理性能,我们在 TabPFN(PHE)中采用了 PHE方法,对来自扩展数据表 5 详述的搜索空间中固定组合的 TabPFN 配置进行集成。对于 TabPFN(PHE),我们首先使用保留验证(holdout validation)依次评估组合中的模型,直到达到时间限制。在所有模型都被评估一次后,我们使用新的数据划分重复保留验证,直到时间限制到达。然后,我们通过对所有评估过的 TabPFN 模型的预测结果进行加权算术平均来集成这些模型。我们使用贪婪集成选择(GES)在保留验证的预测数据上进行 25 次迭代来学习权重。最后,我们移除每个权重为零的模型,对所有剩余模型重新在全部数据上进行拟合,并返回它们预测结果的加权平均值。

image-20250405182815293

(a)TabPFN搜索空间;(b,c)基线搜索空间。

遵循 AutoML 的标准实践,我们采用 GES 是因为它的预测性能通常优于最佳单一模型。由于 TabPFN 具有 ICL(In-Context Learning)特性,我们预期 TabPFN 比传统训练算法更不容易过拟合训练数据;因此,我们选择重复保留验证(如 Auto-Sklearn 1 所用)而非重复交叉验证(如 AutoGluon 所用)。此外,由于 GES 通常会产生稀疏权重向量,我们预期在移除每个权重为零的模型后,最终集成的模型数量将少于其他集成方法(如 bagging)。因此,PHE 在提高 TabPFN 集成的推理效率方面优于其他集成方法。

3.2 定性分析

我们首先通过分析 TabPFN 在一些简单问题上的行为来建立直观理解,并剖析各种数据集特性对模型表现的影响。由于回归问题更容易可视化,我们在定性分析中重点关注这些问题。在图 3a 中,我们将 TabPFN 与多种标准预测器进行了比较,所有方法均使用默认设置。

image-20250403183338398

图3展示了TabPFN和一组基线模型在简单函数上的表现。在所有图中,我们用橙色表示真实情况(ground truth),用蓝色表示模型预测。a部分,每一列代表一个不同的简单函数,每个函数具有一个特征(沿x轴)和一个目标值(沿y轴)。TabPFN能够拟合多种不同的函数,包括带有噪声的函数。b部分,TabPFN能够直接对输出的分布进行建模,这通过预测双缝实验中的光强模式得到了例证。在这个实验中,观察了1000个光子的位置后,TabPFN能够预测出光强的分布模式。

线性回归只能自然地建模线性函数,这导致其预测简单且可解释,但在许多简单函数上会遭遇灾难性失败。多层感知机(MLP)在处理高度非平滑模式的数据集时表现较差,这在阶跃函数上尤为明显。相比之下,TabPFN 无需额外调整即可建模平滑或非平滑的函数类型,包括对阶跃函数的良好近似,尽管 TabPFN 是一种神经网络。作为基于树的方法的代表,CatBoost 只能拟合分段常数函数。尽管这会导致近似误差和不直观的预测,但它避免了灾难性失败。

CatBoost 是一种基于梯度提升的集成学习算法,主要用于处理分类和回归问题。它通过构建一系列决策树并使用梯度下降优化损失函数来提高模型的预测性能。CatBoost 特别擅长处理分类特征,能够自动对分类特征进行编码,并通过使用随机梯度提升来减少过拟合。其主要作用是提高模型的准确性和鲁棒性,特别是在处理具有大量分类特征的数据集时表现优异。

[1706.09516] CatBoost: unbiased boosting with categorical features

TabPFN 相比所有基线方法的主要优势在于其固有的建模不确定性能力,且无需额外成本。传统的回归方法输出单一实数值预测,而 TabPFN 返回目标分布,捕捉预测的不确定性。TabPFN 的这种不确定性建模能力不仅限于简单分布,还能处理复杂的多峰分布。图 3b 通过建模双缝实验中不同缝距和缝宽下光子到达探测器屏幕的光强密度展示了这一点。在这个经典实验中,光子通过两个狭缝,由于光的波动性干涉行为,形成了多峰强度模式。TabPFN 仅通过一次前向传递就能预测这些复杂的模式,仅需 1.2 秒。相比之下,传统方法如 CatBoost 需要训练多个不同分位数的分位数模型,并从这些预测中重建分布。即使专门为这项任务调整 CatBoost,其预测结果仍明显逊色于 TabPFN(见图 3b)。在默认设置下,CatBoost 需要 169.3 秒,并且结果进一步恶化。从定性上看,我们观察到 TabPFN 在预测非常低的密度时更为准确,并且与 CatBoost 相比产生的伪影更少。

3.3 定量分析

我们对 TabPFN 进行了定量评估,使用了两个数据集集合:AutoML Benchmark 和 OpenML-CTR23。这些基准测试包含多样化的现实世界表格数据集,经过精心策划,具有复杂性、相关性和领域多样性。我们从这些基准测试中使用了 29 个分类数据集和 28 个回归数据集,这些数据集最多包含 10,000 个样本、500 个特征和 10 个类别。我们还进一步评估了来自文献 1415 的其他基准测试套件,以及来自Tabular Playground Series的五场 Kaggle 比赛的数据集。

image-20250403184551281

image-20250403184614670

我们将 TabPFN 与最先进的基线方法进行了比较,包括基于树的方法(随机森林、XGBoost (XGB)、CatBoost、LightGBM)、线性模型、支持向量机 (SVMs) 和多层感知机 (MLPs)。

评估指标包括分类任务的 ROC AUC(接收者操作特征曲线下面积;One-vs-Rest)和准确率,以及回归任务的 R 2 R^{2} R2(决定系数)和负 RMSE(均方根误差)。分数按数据集进行归一化,其中 1.0 表示相对于所有基线方法的最佳性能,0.0 表示最差性能。

image-20250405170935101

image-20250405171016215

对于每个数据集和方法,我们进行了 10 次重复,使用不同的随机种子和训练-测试划分(90% 训练,10% 测试)。我们使用随机搜索和五折交叉验证来调整超参数,时间预算从 30 秒到 4 小时不等。所有方法都使用八个 CPU 核心进行评估,TabPFN 额外使用了一个消费级 GPU(RTX 2080 Ti;其他方法没有从中受益,详见扩展数据图 2d)。TabPFN 使用八个 NVIDIA RTX 2080 GPU 预训练了两周,允许对所有新数据集进行单次前向传递的上下文学习。这些适度的计算需求使类似的研究对学术实验室来说是可行的。image-20250405150630229

扩展数据图 2 展示了 TabPFN 和基线模型在其他基准数据集上的性能比较,以及在 GPU 支持下的性能表现:

a部分:在 Grinsztajn 中等规模基准测试中,具有分类特征的分类性能,涵盖7个数据集。
b部分:在 Grinsztajn 中等规模基准测试中,具有数值特征的分类性能,涵盖其15个数据集。
c部分:在 TabZilla 基准测试中的分类性能,该基准测试包含102个数据集,每个数据集的行数少于10,000行,具有500个特征和10个类别。为了进行5折交叉验证,移除了重复的数据集和每个类别样本少于5个的数据集。
d部分:CPU与GPU硬件的性能随时间变化比较:在分类测试基准上,使用八个CPU(CPU)与使用八个CPU和一个GPU(+GPU)运行我们最强的基线模型时的性能随时间变化情况。AutoGluon会自动决定使用哪些资源来训练哪些模型。对于CatBoost和XGB,我们指定这些模型应使用GPU进行训练。区间表示95%的置信区间。

3.4 与SOTA对比

图4a展示了TabPFN与XGBoost、CatBoost和随机森林的调优和默认配置相比所具有的强大开箱即用性能。

对于分类任务,TabPFN在默认设置下的归一化ROC AUC值为0.939,比最强的默认基线CatBoost的0.752高出0.187;在调优设置下,TabPFN的归一化ROC AUC值为0.952,比CatBoost的0.822高出0.13。

对于回归任务,TabPFN在默认设置下的归一化RMSE值为0.923, 比CatBoost的0.872高出0.051;在调优设置下,TabPFN的归一化RMSE值为0.968,比CatBoost的0.875高出0.093。图4b展示了按数据集进行的比较。尽管在某些数据集上CatBoost的表现优于TabPFN,但TabPFN在大多数数据集上胜出。

image-20250405160920047

图4:在我们的测试基准上对 TabPFN 进行比较,这些基准包含最多 10,000 个样本和 500 个特征的数据集。性能在聚合前按数据集进行了归一化,并使用所有基线进行表示;区间代表 95% 的置信区间。Wilcoxon P 值是指两侧 Wilcoxon 符号秩检验 P 值。a 部分,TabPFN 和我们基线的默认版本以及调优版本的平均性能。所有方法都针对 ROC AUC 或 RMSE 进行了调优,因此降低了次要指标的代表性。LGBM,即 LightGBM;MLP,即多层感知机;SVM,即支持向量机;RF,即随机森林;CB,即 CatBoost;XGB,即 XGBoost;Lin,即分类任务的逻辑回归和回归任务的线性回归。右侧的图表显示了所考虑的最强基线的放大分析。b 部分,TabPFN 与其最强基线 CatBoost 的按数据集比较。每个点是一个数据集上的平均得分。c 部分,超参数调优对所考虑方法的影响。x 轴显示了使用算法进行拟合和预测所需的平均时间。

图4c展示了TabPFN和基线模型的性能如何随着超参数搜索时间的增加而提高。TabPFN的默认配置在分类任务平均耗时2.8秒,回归任务耗时4.8秒的情况下,性能优于所有基线模型,即使这些基线模型经过4小时的调优——分别实现了5,140倍和3,000倍的速度提升。我们在扩展数据表1和2中展示了更多指标的比较。

如扩展数据图2所示,与我们的主要基准测试类似,TabPFN在文献 1415 的基准测试中大幅优于所有基线模型。文献14的基准测试尤其值得注意,因为在该基准测试中,基于树的方法之前被认为表现出色。此外,我们在扩展数据表6中展示,在最近完成的Tabular Playground Series中,对于所有五个训练样本少于10,000个的Kaggle竞赛,TabPFN的默认配置均优于CatBoost的默认配置。

4 可解释性

除了具有强大的预测性能外,TabPFN 还展现了关键的基础模型能力,例如数据生成、密度估计、学习可复用嵌入和微调。我们通过在德国信用数据集(包含信用风险信息)和 mfeat-factors 数据集(基于表格表示对手写数字进行分类)上的概念验证实验来展示这些能力。

如图6a所示,TabPFN 可以估计数值特征的概率密度函数,以及分类特征的概率质量函数。计算样本密度可以实现异常检测,以识别诸如欺诈、设备故障、医疗紧急情况或低质量数据等问题。

TabPFN 还可以合成模仿现实世界数据集特性的新表格数据样本,如图6b所示。这使得数据增强或隐私保护数据共享等应用成为可能。

image-20250405195044261

图6展示了TabPFN作为表格基础模型的应用。a和b部分,在德国信用数据集上,我们进行了数据密度估计(a)和新合成样本的生成(b)。c部分,我们展示了在手写数字数据集(mfeat-factors)上,表明我们学习的嵌入是每个样本的有用表示,不同类别形成不同的聚类。d部分,我们展示了针对特定任务集对TabPFN进行微调的应用。在包含各种正弦曲线的数据集上进行微调(顶部),我们看到该模型对另一个正弦曲线数据集的预测更为准确。

TabPFN 的架构产生了有意义的特征表示,这些表示可以用于下游任务,例如数据插补和聚类。我们在图6c中提取并可视化了来自 mfeat-factors 数据集的所学嵌入,与原始数据相比,在前两个主成分上显示了改进的类别分离。

此外,我们展示了 TabPFN 通过在相关数据集上进行微调来提高性能的能力。与基于树的方法不同,TabPFN 的神经架构支持对特定数据集类别的微调。我们使用不同偏移量的正弦曲线数据集进行了概念验证实验。图6d显示了一个微调结果的例子。我们的分析(见扩展数据图4)表明,即使在微调和测试任务之间的标签差异显著时,TabPFN 也能成功地迁移知识,并且随着分布变得更加相似,性能会得到提高。例如,这可以实现对医学研究中的一系列数据集进行微调,以获得改进的医学诊断通用模型。

image-20250413193627919

扩展数据图 4 | 在二维正弦曲线数据集上微调 TabPFN。

(a) 不同偏移的二维正弦曲线数据集示例。

(b) 在随机训练-测试偏移下进行 50 次微调的损失曲线。颜色表示训练和测试之间的偏移。TabPFN 显示出正向迁移,即在分布更相似时表现更好。对于 π 的数据偏移,测试集需要预测与微调数据相反的标签。然而,即使在这种情况下,TabPFN 仍然能够泛化。

5 结论

TabPFN 是表格数据建模领域的一项重大变革,它利用上下文学习(ICL)自主发现了一种高效的算法,该算法在最多 10,000 个样本和 500 个特征的数据集上的表现优于传统的人工设计方法。这种向基于合成数据训练的基础模型的转变,为各个领域的表格数据分析开辟了新的可能性。

未来的研究方向包括扩展到更大的数据集、处理数据漂移、研究相关表格任务的微调能力以及理解我们方法的理论基础。未来的工作还可以探索创建专门的先验来处理时间序列和多模态数据等数据类型,或专门的模态如心电图(ECG)、神经影像数据和遗传数据。随着表格数据建模领域的不断发展,我们相信像 TabPFN 这样的基础模型将在赋能研究人员方面发挥关键作用。为了促进 TabPFN 的广泛应用,在“用户指南”部分我们讨论了如何有效地使用它。

参考

TabPFN: 预训练表格基础模型 - 知乎

Nature: 准确预测小样本数据的表格基础模型-CSDN博客

[2207.01848] TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second

<think>嗯,用户问的是TabPFN运行时为什么要连外网。我需要先了解TabPFN是什么。TabPFN是一种基于Transformer的先验拟合网络,主要用于表格数据的分类任务。它的特点是训练速度快,适合小数据集。 接下来,用户的问题涉及到联网需求。可能的原因有几个:首先,模型可能在运行时需要下载预训练的权重文件。很多机器学习模型,尤其是大型的,如BERT、GPT等,都会在首次使用时下载预训练参数。如果TabPFN也是这样,那用户第一次使用时会需要网络连接。 其次,是否有在线服务的依赖?有些框架或工具会在运行时调用外部API,例如获取实时数据或验证许可证。如果TabPFN的设计中包含了这些服务,那么联网就是必须的。需要确认TabPFN的文档是否有相关说明。 然后,考虑安全或授权机制。有些软件需要联网进行用户验证,确保许可证有效,或者防止盗版。如果TabPFN采用此类保护措施,联网也是必要的。 另外,可能用户提到的“连外网”是指访问特定的服务器,比如更新模型或获取依赖项。这时候,即使本地有部分数据,也可能需要连接服务器获取最新版本或补充资源。 不过,也有可能用户遇到的是配置问题,比如误报的网络需求,或者某些库在导入时默认尝试访问网络。这时候需要检查代码或文档,确认是否是设计如此还是环境问题。 需要验证TabPFN的官方文档或源码。例如,查看GitHub仓库的安装说明,是否有提及自动下载模型权重。如果模型文件较大,通常会托管在远程服务器,首次运行时会下载到本地缓存,之后可能不再需要网络。 另外,可能依赖项中的某些库需要网络,比如nltk需要下载语料库,但这种情况通常会有明确的提示,与TabPFN本身的联网需求不同。 总结可能原因:预训练模型下载、许可证验证、在线服务依赖、或配置问题。建议用户检查文档,查看是否有离线模式,或者提前下载模型文件。如果必须联网,可能需要配置代理或确保网络畅通。</think>TabPFN运行时需要连接外网的主要原因与它的技术架构和功能特性有关,以下是逐步分析: --- ### **1. 预训练模型下载** **核心机制**:TabPFN是基于大规模预训练的概率函数网络(Prior-Data Fitted Networks),其核心能力来源于预训练的权重文件。 - **首次运行时**:若本地缓存中没有预训练模型文件(通常数百MB),系统会从官方服务器(如Hugging Face Hub或AWS)自动下载。 - **验证示例**: ```python from tabpfn import TabPFNClassifier classifier = TabPFNClassifier() # 此处触发模型下载 ``` --- ### **2. 依赖库的远程加载** **框架设计**:TabPFN依赖PyTorch和Transformer类库,部分组件可能动态加载: - **自动版本检查**:确保使用的库版本与模型兼容(需访问PyPI或GitHub)。 - **特定功能扩展**:如集成特征工程工具时可能需临时获取代码。 --- ### **3. 学术验证机制(可选)** **研究属性**:作为学术导向工具,TabPFN可能内置匿名化数据统计功能: - **使用情况上报**:统计模型调用频率、任务类型(可通过`disable_statistics=True`关闭)。 --- ### **4. 网络依赖的解决方案** 若需离线运行,可通过以下方式处理: 1. **提前下载模型**: ```bash python -c "from tabpfn import TabPFNClassifier; TabPFNClassifier()" # 触发下载后缓存至本地 ``` 2. **指定本地路径**: ```python classifier = TabPFNClassifier(model_path='/your/local/path/tabpfn.pt') ``` 3. **防火墙配置**:允许`tabpfn`进程仅访问模型托管域名(如huggingface.co)。 --- ### **关键注意事项** - 后续调用会复用本地缓存的模型,无需重复下载。 - 企业部署时建议通过内部镜像托管模型文件(需符合开源协议)。 如需进一步排查具体网络请求,可使用`Wireshark`或`tcpdump`监控流量目标域名。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值