Nature: 准确预测小样本数据的表格基础模型

图片

文章信息

图片

论文题目为《Accurate predictions on small data with a tabular foundation model》,文章来自Nature正刊,论文提出了一种名为Tabular Prior-data Fitted Network(TabPFN)的表格基础模型,用于在小到中型表格数据上进行准确的预测,并在多项基准测试中显著优于现有的方法,如梯度提升决策树(如CatBoost、XGBoost等)。

图片

摘要

图片

表格数据(以行和列组织的结构化数据)在生物医学、材料科学、经济学等领域广泛应用,但传统方法(如梯度提升决策树)在处理小样本数据集时存在知识迁移能力弱、不确定性建模不足等问题。本文提出了一种基于Transformer架构的表格基础模型——Tabular Prior-data Fitted Network (TabPFN),通过在大规模合成数据上进行预训练,实现了对小到中型表格数据(样本量≤10,000,特征数≤500)的高效预测。实验表明,TabPFN在分类和回归任务中显著优于现有方法,例如在分类任务中以2.8秒的推理时间超越经过4小时调优的CatBoost集成模型(速度提升5,140倍),同时具备生成数据、密度估计和可解释性等基础模型特性。这一研究为表格数据分析提供了一种端到端学习的范式,有望加速科学发现与决策优化。

图片

引言

图片

表格数据的重要性与挑战:表格数据是科学研究与工业应用的核心载体。例如,在药物发现中,化学性质表格用于预测分子活性;在气候科学中,气象观测表格用于建模气候变化。然而,表格数据具有高度异质性:特征类型多样(布尔型、分类型、连续型)、存在缺失值与异常值,且不同数据集间特征语义差异显著(如“温度”在材料科学与气象学中的含义不同)。传统方法(如梯度提升决策树)虽在单数据集上表现优异,但难以实现跨数据集的知识迁移,且无法建模预测不确定性。

深度学习的局限与机遇:尽管深度学习在图像和自然语言处理领域取得突破,但在表格数据上的应用仍受限。原因包括:(1)表格数据缺乏空间或序列的局部相关性;(2)小样本场景下模型易过拟合;(3)传统树模型在计算效率上具有优势。然而,基于Transformer的大模型通过上下文学习(In-Context Learning, ICL)展现了强大的算法学习能力,为表格数据分析提供了新思路。

本文提出TabPFN,首次将基础模型范式引入表格数据分析,核心创新包括:

(1)基于合成数据的预训练框架:通过结构因果模型生成百万级合成数据集,覆盖真实数据的多样性挑战(如缺失值、异常值、无关特征);

(2)双向注意力架构:设计针对表格结构的Transformer层,实现样本与特征的双向交互,支持快速推理与状态缓存;

(3)多功能基础模型特性:除预测外,支持数据生成、密度估计、特征嵌入与微调,为下游任务提供统一工具。

图片

方法论

图片

TabPFN的流程分为三个阶段(图1):

(1)合成数据生成:基于结构因果模型(SCM)生成多样化数据集,模拟真实数据分布与挑战(如缺失值、异常值);

(2)预训练:使用Transformer模型学习从上下文(训练集)到目标(测试集)的映射,优化交叉熵损失;

(3)推理:将预训练模型应用于真实数据集,以前向传播完成训练与预测。

图1 模型框架图

1.合成数据生成

TabPFN的核心创新之一是通过大规模合成数据预训练模型,使其能够泛化到多样化的真实表格数据任务。本节详细阐述合成数据的生成机制,包括结构因果模型(Structural Causal Models, SCMs)的构建计算图传播后处理技术,确保生成数据覆盖真实场景中的复杂性(如缺失值、异常值、无关特征等)。步骤如图2:

(1)因果图采样:每个数据集的因果图节点数从对数均匀分布中采样,确保生成不同复杂度的图结构(如简单线性关系或深层非线性依赖)。采用偏好连接机制(Preferential Attachment)生成随机无标度网络。允许生成多个独立子图后合并,模拟真实数据中存在的无关特征(即与目标变量无因果路径的特征)。

(2)节点类型与初始化

根节点:代表数据生成过程的起点,其值通过正态分布、均匀分布、混合分布初始化。

中间节点与叶节点:通过计算图传播生成,代表衍生特征或目标变量。

(3)计算图传播与特征生成:每个节点的值通过其父节点的值经计算模块处理后生成,模块类型包括:非线性变换模块、离散化模块、决策树模块、噪声注入。

(4)特征与目标选择:完成计算图传播后,从图中选择特征列与目标列。

(5)后处理增强现实性:为逼近真实数据分布,对生成数据施加多种扰动:非线性扭曲、量化与分箱、缺失值注入、异常值生成。

图2 合成数据生成

2.模型架构设计

TabPFN的模型架构针对表格数据的特性进行了深度优化,在标准Transformer编码器的基础上,引入了双向注意力机制行列解耦计算高效内存管理策略,实现了对表格结构的显式建模与高效推理。以下从核心模块、优化策略与输出建模三方面详细解析其设计原理。

(1)双向注意力机制:传统Transformer将输入视为单一序列处理,忽略了表格数据中行(样本)与列(特征)的二维结构。TabPFN提出双向注意力机制,将表格分解为行与列两个正交维度,分别进行注意力计算,充分捕捉样本间与特征间的依赖关系。特征注意力(Feature Attention):在同一行(样本)内,建模不同特征间的交互关系。例如,在医疗数据中,“年龄”与“血压”可能具有非线性关联。样本注意力(Sample Attention):在同一列(特征)内,建模不同样本间的全局依赖。例如,在金融风控中,相似客户的信用评分可能相互影响。双向注意力堆叠每个Transformer层依次执行特征注意力与样本注意力,形成“行→列”或“列→行”交替的计算流(图1b),不同层的注意力头共享部分参数,减少模型复杂度,增强泛化能力。

(2)针对表格数据规模大、特征异构性强的特点,TabPFN在计算与内存层面进行了多项优化:行列解耦计算、混合精度训练、动态内存分配、状态缓存与复用。
(3)输出建模:TabPFN不仅预测目标值,还建模其不确定性,支持回归与分类任务的概率输出。

图片

实验

图片

1.定量分析

在两个数据集上定量评估了TabPFN:AutoML Benchmark36和OpenML-CTR2337。这些基准包括不同的现实世界表格数据集,针对复杂性、相关性和领域多样性进行了精心策划。根据这些基准测试,使用了29个分类数据集和28个回归数据集,这些数据集有多达10000个样本、500个特征和10个类别。将TabPFN与最先进的基线进行了比较,包括基于树的方法(随机森林38、XGBoost(XGB)、CatBoost9、LightGBM8)、线性模型、支持向量机(SVM)和MLPs。评估指标包括ROC、AUC和分类准确性,以及R2和负RMSE用于回归,结果如图3。

图3 实验结果

2.评估各种数据属性

在图4a、b中,展示了TabPFN对数据集特征的鲁棒性,而传统上基于神经网络的方法很难处理这些特征。图5a提供了TabPFN在各种数据集类型上的性能分析。首先,添加无信息特征(来自原始数据集的随机混洗特征)和异常值(将每个单元格以2%的概率乘以0和异常值因子之间的随机数)。结果表明,TabPFN对无信息特征和异常值具有很强的鲁棒性,这对神经网络来说通常很困难,正如MLP基线所示。其次,尽管丢弃样本或特征会损害所有方法的性能,但使用一半的样本,TabPFN的性能仍然与使用所有样本的次佳方法一样好。

在图4b中,将测试数据集划分为子组,并对每个子组进行分析。我们根据分类特征的存在、缺失值、样本数量和数据集中的特征数量创建子组。样本数和特征数子组被拆分,使得三分之一的数据集落入每个组中。可以看到,与其他方法相比,这些特征都没有强烈影响TabPFN的性能。

图4 实验结果

图片

结论

图片

TabPFN的合成数据生成框架通过结构因果模型多样化后处理技术,系统性地模拟了真实表格数据的复杂性。这一方法不仅规避了真实数据收集中的隐私与版权问题,还通过控制生成参数实现了数据分布的全面覆盖,为模型预训练提供了高质量的算法学习环境。实验表明,由此生成的合成数据能够有效支持模型在未知真实任务上的强泛化能力。

<think>嗯,用户现在问的是关于TabPFN的源码。首先,我需要回顾一下之前的对话历史。之前用户询问了TabPFN是否可以用于预测船舶载货量,我详细分析了可行性和实现路径,包括分类改造和回归扩展的方案。现在用户进一步想要了解源码,这说明他们可能已经对TabPFN的应用有了一定的兴趣,现在想深入理解其内部机制,或者可能想要自己修改或扩展模型。 接下来,我需要确定用户的真实需求。用户可能是一名研究人员、数据科学家或者开发者,想要通过查看源码来更好地理解模型的实现细节,或者进行二次开发。因此,回应的内容需要详细但不过于技术化,同时要考虑到用户可能的不同背景。 用户可能想知道如何获取源码,源码的结构,核心模块的功能,以及如何参与贡献。此外,他们可能对预训练机制、快速推理实现以及如何自定义模型感兴趣。需要确保回答覆盖这些方面,并且提供相关的资源链接,比如GitHub仓库和论文,以便用户进一步查阅。 在组织回答时,应该分点说明,结构清晰。首先介绍源码的获取方式,然后概述整体架构,接着详细讲解核心模块,再讨论预训练和推理的机制,最后提供二次开发的建议和注意事项。需要避免使用过于专业的术语,或者在使用时进行解释,确保用户能够理解。 同时,需要注意用户可能对某些技术细节不太熟悉,比如Transformer架构或贝叶斯推断,因此在解释这些部分时需要简明扼要,必要时提供类比或简单解释。此外,提醒用户关于硬件依赖和社区支持的问题,帮助他们评估是否适合深入源码研究。 最后,确保所有链接正确无误,并且建议用户参考官方文档和论文,以便他们获得最准确和最新的信息。如果有社区或论坛的链接,也可以提供,方便用户寻求帮助和讨论问题。</think>以下是关于TabPFN源码的核心信息及解读指南,按技术层级逐步展开: --- ### **1. 源码获取与官方资源** - **代码仓库**: TabPFN官方代码托管在GitHub: [https://github.com/automl/TabPFN](https://github.com/automl/TabPFN) (注意:需遵守Apache-2.0开源协议) - **关键分支**: - `main`:稳定版实现 - `experimental`:包含未正式发布的扩展功能(如回归任务原型) - **依赖项**: 源码基于PyTorch实现,需配合`transformers`库(用于预训练模型加载)。 --- ### **2. 源码结构概览** ```plaintext TabPFN/ ├── tabpfn/ │ ├── models/ # 模型定义核心 │ │ ├── tabpfn.py # TabPFN主网络结构 │ │ └── transformer.py # 自定义Transformer层 │ ├── priors/ # 预训练先验知识生成器 │ │ ├── diffusions.py # 数据扩散过程定义 │ │ └── prior_utils.py # 先验采样工具 │ ├── scripts/ # 训练与评估脚本 │ ├── utils/ # 数据预处理工具 │ └── ... # 其他辅助模块 ├── experiments/ # 实验配置与日志 └── tests/ # 单元测试 ``` --- ### **3. 核心模块解析** #### **(1) 模型架构 (`models/tabpfn.py`)** - **核心类**:`TabPFN` 继承自`nn.Module`,包含: - **输入编码层**:将分类/数值特征映射到统一嵌入空间 ```python self.input_encoder = MixedInputEncoder(feature_embedding_sizes) ``` - **Transformer骨干网络**:多层堆叠的注意力机制 ```python self.transformer = TransformerModel( d_model=config.d_model, nhead=config.nhead, num_layers=config.num_layers ) ``` - **输出头**:分类概率预测(`classification_head`)或回归预测(实验分支) #### **(2) 预训练机制 (`priors/diffusions.py`)** - **关键逻辑**:通过**合成数据生成器**创建预训练任务 ```python def generate_batch(batch_size): # 生成虚拟特征X和标签y,模拟真实数据分布 X = sample_features(prior_distributions) y = apply_synthetic_function(X) + noise return X, y ``` - **先验知识融合**:在微调阶段,模型利用预训练的归纳偏置快速适应新数据集。 #### **(3) 快速推理 (`models/tabpfn.py:forward`)** - **批量预测优化**:通过缓存中间结果加速推理 ```python def forward(self, x, cache=None): if cache is None: x_emb = self.input_encoder(x) features = self.transformer(x_emb) else: # 使用缓存的嵌入结果 features = cache return self.classification_head(features) ``` --- ### **4. 二次开发指引** #### **(1) 扩展回归功能** - **修改位置**: 在`TabPFN`类中添加回归输出头: ```python class TabPFN(nn.Module): def __init__(self, ..., regression_output_dim=1): super().__init__() self.regression_head = nn.Linear(d_model, regression_output_dim) def forward(self, x, task_type='classification'): if task_type == 'regression': return self.regression_head(features) ``` #### **(2) 自定义先验分布** - **编辑路径**:`priors/diffusions.py` 修改`sample_features()`中的分布参数,例如: ```python def sample_features(): # 原代码使用均匀分布 # 可改为船舶领域相关的分布(如船舶尺寸的Weibull分布) feature = weibull_dist.sample() ``` --- ### **5. 关键代码片段示例** #### **Transformer层改进(位置编码)** ```python class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward(self, x): x = x + self.pe[:x.size(0), :] return x ``` --- ### **6. 运行环境要求** - **硬件**:官方预训练模型需**GPU支持**(建议显存 ≥ 16GB) - **软件依赖**: ```plaintext Python >= 3.8 PyTorch >= 1.12 transformers >= 4.18 ``` --- ### **7. 注意事项** 1. **预训练权重加载**:官方提供的`TabPFNClassifier`会自动下载预训练参数(约2.3GB) 2. **计算效率限制**:由于使用全注意力机制,输入特征数建议 ≤ 100列 3. **社区支持**:当前开发活跃度较低,重大修改建议fork后自主维护 --- 如需深入理解具体实现细节,推荐结合原始论文阅读: [TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second](https://arxiv.org/abs/2207.01848)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

当交通遇上机器学习

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值