ICLR 2025 | 超越梯度提升树!南大提出面向表格数据的表示学习方法,预测任务取得新SOTA...

aa8fde6a6a7b7bf2909997b5e5c4d553.gif

©PaperWeekly 原创 · 作者 | 尹怀鸿

单位 | 南京大学硕士研究生

研究方向 | 表格数据,多模态学习

近年来,深度学习在图像、语音、自然语言处理等领域大放异彩,但在一个看似简单却至关重要的场景——表格数据预测任务中,深度学习的能力却难以预约如梯度提升决策树(GBDT)等经典机器学习方法。

表格数据预测即最经典的机器学习任务,数据以行列结构组织,给定样本的向量表示,需要对其标记进行预测,在医疗记录、金融风控、电商推荐等行业中,表格数据仍作为核心。

随着深度学习在相关领域中的应用,研究者们也尝试思考,是否深度神经网络的能力能拓展至表格预测领域。近年来,有大量工作,从数据处理技巧、网络结构设计、大语言模型融入等多个方面进行了尝试,但深度神经网络的效果依然有限,或仅能在部分领域的数据上相比于树模型取得突破。

梯度提升树的出色的性能给予研究者启发:也许基于传统方法进行改进,能够让深度方法在传统方法上锦上添花,在表格数据上实现能力跨越。

在这一思路的指引下,南京大学团队从一个经典的可微 K 近邻算法——近邻成分分析(Neighbourhood Component Analysis,NCA)出发,通过不断加入深度学习的技术,提出 ModernNCA 方法,用便捷的形式取得深度学习模型在表格数据分类、回归任务上的性能突破,在 300 个数据集上展现出优越于其他深度方法的性能。

78bfa70d86c8e723ef036ccb1e9e0ae2.png

论文标题:

Revisiting Nearest Neighbor for Tabular Data: A Deep Tabular Baseline Two Decades Later

论文链接:

https://openreview.net/forum?id=JytL2MrlLT

工具包链接:

https://github.com/qile2000/LAMDA-TALENT

ModernNCA代码链接:

https://github.com/qile2000/LAMDA-TALENT/blob/main/LAMDA_TALENT/model/models/modernNCA.py

061fac2b7c64ffd002f644d7eac8f315.png

回顾近邻成分分析

K 近邻算法(K-Nearest Neighbor,KNN)是一种简单而直观的非参数方法,广泛应用于分类和回归任务。KNN 的核心思想是通过度量样本之间的距离,从训练数据中寻找与给定样本最相似的 K 个邻居,并基于这些邻居的信息来预测目标值。

然而,KNN 的一个问题在于其基于原始特征空间进行距离度量,在高维数据中容易受到“维度灾难”的影响,导致性能下降。

为了解决这个问题,Jacob Goldberger 等人在 2004 年提出了近邻成分分析(NCA)[1],NCA 通过学习一个映射矩阵,使得在新的特征空间中,同一类的样本彼此靠近,而不同类的样本则远离,从而增强 KNN 算法的分类效果。

94b8722cd38611f670ca370bb8d566ea.png

▲ 图1:NCA 进行表格预测思路

NCA 的基本思想如图 1 所示。在图中,绿色的方框是原始的特征空间,圆形、方形等各种形状图案是不同类别的样本。NCA 将样本映射到一个新的空间中(图中橙色的方框),在这个空间中,样本之间的距离能够更好地反映其类别关系。

具体来说,NCA 通过最大化近邻的相似度来优化映射矩阵。假设数据集为,其中 是第 个样本的特征向量, 是其对应的标签。NCA 的目标是通过对样本点进行线性变换,学习一个映射矩阵 ,使得同类样本之间的距离尽可能小,而异类样本之间的距离尽可能大。NCA定义了样本 位于 近邻的概率为:

7496e133f8c0fb32126ba368d0c95627.png

以此,样本 属于类别 的概率为:

a41da02127c54daef55c8e8b9adf57c5.png

NCA 通过最大化所有训练样本的 之和,学习映射矩阵 ,在测试阶段,在映射矩阵 L 投影的空间运行 K 近邻算法,对样本进行分类。

然而,尽管 NCA 提出于二十年前,并且提出的时候论文就将其应用于 iris 等表格数据分类任务,但由于 NCA 的预测能力远远不及 RandomForest、XGBoost 等机器学习方法,后续 NCA 在研究者视野中逐渐淡去,在 sklearn 工具包中也仅作为一种可视化降维方法出现。

本文重新分析并发现 NCA 的潜力,通过深度学习技术,对 NCA 进行一系列改进,学习面向表格数据高质量的特征表示,不仅能让其性能大幅增强,而且相较于其他表格深度学习方法,在时间,性能,内存消耗上有着更优秀的平衡。

3a888b2439d798137419243ba767831a.png

改进步骤 1

原始 NCA 使用线性投影,且局限于分类场景。我们对预测公式进行修改,假设 在分类任务中是 one hot 形式,在回归任务中是数值的形式,对于样本预测的标签值为:

00e3ecc09ffa66939dd43f907d921530.png

——公式(1)

上述的公式中 代表变换的形式,在原始 NCA 中使用的是线性层。基于此公式,通过对近邻样本 label 的加权,在分类任务中 是样本 的预测概率分布,在回归任务中 是样本 的预测值。

于是我们对分类任务使用负对数似然损失(negative log-likelihood loss),对回归任务使用均方误差损失进行训练。注意,这不同于原始的 NCA,用最大化概率和进行分类任务训练,我们的分类任务损失相当于最大化对数概率之和。

在预测策略上,原始的 NCA 学习了映射变换 后,在映射空间运用 KNN 算法寻找近邻进行预测。而我们采用的是直接使用公式(1)中得到的预测值进行预测。

此外,原始的 NCA(scikit-learn 实现)默认进行降维,即投影矩阵 要求满足 ,我们去除了这一限制,并将优化器由 L-BFGS 改变为 SGD。

我们发现使用了这些改进后,NCA 在预测性能上有了显著的提升,即便只有一个线性映射层,其预测性能也能比肩 MLP。我们把这一改进后的 NCA 版本称为 L-NCA。

d63fe9d6ff0dfcef47109e8459b4d842.png

改进步骤 2

尽管线性版本的 L-NCA 已展现出潜力,但其表达能力仍受限于线性映射。为了充分释放深度学习的优势,研究者进一步引入现代深度学习技术,提出了ModernNCA(M-NCA),核心改进主要包含以下两点:

3.1 深度非线性架构

原始的 L-NCA 仅通过线性映射提取特征,而 M-NCA 将线性投影 替换为多层非线性模块。具体地,每个模块由批归一化(batch normalization)、线性层、ReLU 激活、Dropout 和另一个线性层构成,数学形式为:

4d3d490b01e5c68bc0c6e7b25e63efc6.png

通过叠加多个此类模块,模型能够捕捉复杂的特征交互。此外,对于数值型特征,M-NCA 引入了 PLR(Periodic-Linear-ReLU)编码,将数值映射到高维空间,增强非线性表达能力。

3.2 加入采样策略

传统 NCA 需计算目标样本与全部训练数据的距离,当训练集样本量很大(比如上百万)时候,计算开销巨大。为此,M-NCA 提出随机近邻采样(Stochastic Neighborhood Sampling, SNS)策略进行训练:

在训练阶段,M-NCA 每个批次仅随机采样部分训练数据(比如 30%)作为邻域候选,以降低计算量和显存消耗。在推理阶段,M-NCA 仍使用全体训练数据搜索近邻,保证预测精度。通过实验发现,SNS 不仅能显著加速训练,还能提升模型的泛化性能。

f35b790ac2329e85f832f2a118d9fc0d.png

实验结果

4.1 主实验结果

研究团队在包含 300 个表格数据集(180 个分类数据集和 120 个回归数据集)的大规模基准测试中验证了 ModernNCA 的性能 [2]。

图 2 展示了不同表格数据方法的平均排名以及两两 Wilcoxon-Holm 检验的结果,实验结果显示,ModernNCA 在分类任务的平均准确率与回归任务的 RMSE(均方根误差)上均显著优于现有深度模型,并与当前最优的树模型 CatBoost 性能相当。

2050d1bd1c4cf0eb8fb037d3bb74cb3d.png

▲ 图2:表格数据方法平均排名的临界图

图 3 对图 2 中的方法的性能,运行时间和显存消耗进行比较,纵轴比较了不同方法的训练时间,横轴比较了不同方法的平均排名,圆圈的半径表示训练是消耗的显存。从图中可见,相较于其他的深度方法,ModernNCA 具备优秀的性能同时保持了合理的显存占用和高效的训练时间。

8c671e2b891d1ed0fbaed672e97883d1.png

▲ 图3:表格数据方法性能、运行时间和显存消耗对比

4.2 消融实验

在 27 个分类数据集和 18 个回归数据集上,我们对改进的各个组件的有效性进行评估:

从表 1 可以看出,从原始的 NCA(表示为 NCAv0)中不断进行:不限制维度提升,优化器改进为 SGD,使用负对数似然损失,Soft 的近邻预测方式等改进,NCA 的平均排名持续下降,性能不断提升,即便只有一层线性映射,已经超过 MLP 的性能。

2e3045e483990441a18bb83139de74c7.png

▲ 表1:NCA 加入不同改进组件后的平均排名

从表 2 可以看出,允许学习更深层的映射后,ModernNCA 具有更低的排名;此外,用带 batch normalization 的 MLP 学习 要好于其他的学习方式。

93269b747abe84073739d1db0643dc6a.png

▲ 表2:映射层 不同结构的平均排名

从图 4 可以看出,ModernNCA 采用了随机近邻采样策略后,在 30-50% 的采样率产生了好于全量样本作为近邻训练的结果。

511bbbeae94d402edd5097a65285db7e.png

▲ 图4:不同采样比例的平均排名

4.3 可视化结果

以 AD 数据集为例,我们使用 TSNE 对不同方法学习的表征进行对比,如图 5 所示。

45b31f2163adfac7580887314bcfef48.png

▲ 图5:学习表征可视化结果

可以看出,相较于原始特征,不同方法都学到了更加易于预测的特征空间。使用对比学习的方法(TabCon)使得同类的样本聚集成一个簇,难以处理难分样本。而 ModernNCA 会将同类的样本聚类成多个簇,保证相似的样本位置相近。ModernNCA 的机制能更好地学习样本间的局部关系,适应表格数据的特性。

5ecbec30b2125d96d4ee8930cfe08d10.png

总结

ModernNCA 通过融合经典近邻思想与深度学习技术,成功让二十年前的 NCA 算法焕发新生。其在 300 个数据集上的实验表明:ModernNCA 可以作为深度表格预测的一个强大的基线方法,相较于梯度提升树和其他深度表格方法展现出了强劲的性能。这一方法也启示研究者:对传统方法的现代化改造,可能是解锁深度学习潜力的关键钥匙。

outside_default.png

参考文献

outside_default.png

[1] Goldberger, Jacob, et al. 'Neighbourhood components analysis.'Advances in neural information processing systems 17 (2004).

[2] Ye, Han-Jia, et al. 'A closer look at deep learning on tabular data.' arXiv preprint arXiv:2407.00956 (2024).

更多阅读

3930ca867a40395b03393e303a9ca4da.png

41e13c602b6ea2450f88898664cba346.png

b86ec334426ea8cbacc05721b8957cbd.png

221d43ae8e50933096178084bf802f07.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

258f8cd678b21157c85486946aeaabde.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

·

c7ff9732c423a2cdbf0a452c6fa6143d.jpeg

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值