Revisiting Deep Learning Models for Tabular Data
arxiv [Submitted on 22 Jun 2021 , last revised 26 Oct 2023 ]
代码:https://github.com/Yura52/tabular-dl-revisiting-models
https://github.com/yandex-research/tabular-dl-revisiting-models
摘要
本文回顾了为表格数据设计的各种深度学习模型,并强调了这些模型之间缺乏适当的比较,这使得很难确定性能最佳的模型。它确定了两种简单而强大的深度学习架构:类似 ResNet 的模型和经过调整的 Transformer 模型,这两个架构在不同任务中都表现出强劲的性能。作者使用相同的训练和调整协议将这些模型与许多现有架构进行了比较,以确保进行公平比较。他们还将最佳深度学习模型与梯度增强决策树进行了比较,并得出结论,没有适用于所有任务的普遍优越解决方案。他们的实验源代码已在 GitHub 上提供,允许其他研究人员在他们的工作基础上进行复制和构建。
1 Introduction(引言)
由于深度学习在图像、音频和文本等数据领域取得了巨大成功,研究者们对将这些成功扩展到表格数据问题上表现出了极大的兴趣。在这些问题中,数据点被表示为异构特征向量的集合,这在工业应用和机器学习竞赛中很常见。神经网络在这些领域有一个强大的非深度竞争对手,即梯度提升决策树(GBDT)。除了可能更高的性能外,使用深度学习处理表格数据的吸引力在于,它允许构建多模态流水线,其中输入的一部分是表格数据,其他部分包括图像、音频和其他对深度学习友好的数据。这样的流水线可以针对所有模态通过梯度优化进行端到端的训练。由于这些原因,最近提出了许多深度学习解决方案,并且新的模型不断涌现。
不幸的是,由于缺乏像计算机视觉领域的ImageNet或自然语言处理领域的GLUE这样的建立基准,现有的论文使用不同的数据集进行评估,提出的深度学习模型通常没有得到充分的相互比较。因此,从当前文献中不清楚哪个深度学习模型通常表现更好,以及GBDT是否被深度学习模型超越。此外,尽管提出了大量新架构,该领域仍然缺乏简单可靠的解决方案,这些解决方案可以在适度的努力下实现竞争性能,并在许多任务中提供稳定的表现。在这方面,多层感知器(MLP)仍然是该领域主要的简单基线,然而,它并不总是对其他竞争者构成重大挑战。
描述的问题妨碍了研究进程,使论文的观察结果不够有决定性。因此,作者认为,审查该领域的最新发展并提高表格深度学习的基线标准是及时的。作者从假设开始,即在表格数据的背景下,可能还没有充分探索在其他领域经过充分研究的深度学习架构块,这些块可以用来设计更好的基线。因此,作者从其他领域的知名且经过实战考验的架构中获得灵感,并为表格数据获得了两个简单的模型。第一个是一个类似ResNet的架构(He et al., 2015b),第二个是FT-Transformer——作者对Transformer架构的简单适配(Vaswani et al., 2017),用于表格数据。然后,作者在相同的训练和超参数调整协议下,将这些模型与许多现有解决方案进行了多样化任务的比较。首先,作者发现没有任何考虑的深度学习模型能够一致性地超越类似ResNet的模型。鉴于其简单性,它可以作为未来工作的强基线。其次,FT-Transformer在大多数任务上表现出最佳性能,成为该领域的一个新的强大解决方案。有趣的是,FT-Transformer被证明是表格数据的更通用架构:它在比更“传统”的ResNet和其他深度学习模型更广泛的任务范围内表现良好。最后,作者将最好的深度学习模型与GBDT进行了比较,并得出结论,仍然没有普遍优越的解决方案。
我们总结本文的贡献如下:
- 作者在一个多样化的任务集合上彻底评估了表格深度学习的主要模型,以调查它们的相对性能。
- 作者证明了一个简单的类似ResNet的架构是表格深度学习的有效基线,在现有文献中被忽略了。鉴于其简单性,我们推荐这个基线用于未来的表格深度学习工作。
- 作者介绍了FT-Transformer——一个针对表格数据的Transformer架构的简单适配,成为该领域的一个新的强大解决方案。作者观察到它是一个更通用的架构:它在比其他深度学习模型更广泛的任务范围内表现良好。
- 作者揭示了在GBDT和深度模型中仍然没有普遍优越的解决方案。
2 Related work(相关工作)
在表格数据问题上,“浅层”的最新技术目前是决策树集成,例如梯度提升决策树(GBDT)。GBDT通常是各种机器学习竞赛中的首要选择。目前,有几个GBDT库被广泛使用,包括XGBoost、LightGBM和CatBoost。尽管这些实现在细节上有所不同,但在大多数任务上,它们的性能并没有太大差异。 近年来,为了表格数据开发了大量的深度学习模型。这些模型大致可以分为以下三组:
可微分树:第一组模型受到决策树集成在表格数据上强大性能的启发。由于决策树不是可微分的,不允许梯度优化,它们不能作为以端到端方式训练的流水线的组件。为了解决这个问题,一些工作提出了在内部树节点“平滑”决策函数,使整体树函数和树路由可微分。
基于注意力的模型:由于注意力基础架构在不同领域(如图像和自然语言处理)的普遍成功,一些作者也提出在表格数据深度学习中使用类似注意力的模块。
显式建模乘法交互:在推荐系统和点击率预测的文献中,一些工作批评了多层感知器(MLP)不适合建模特征之间的乘法交互。受到这种动机的启发,一些工作提出了不同的方法将特征乘积融入MLP中。
此外,文献还提出了一些其他架构设计,这些设计不能明确归入上述任何一组。总体而言,社区开发了各种在不同基准上评估且很少相互比较的模型。作者的工作旨在建立一个公平的比较,并识别那些一致提供高性能的解决方案。
3 Models for tabular data problems(表格数据问题的模型)
在本节中,作者描述了工作中突出的主要深度学习架构,以及在比较中包含的现有解决方案。由于作者认为该领域需要强大的易于使用的基线,他们在设计ResNet(第3.2节)和FT-Transformer(第3.3节)时尽可能多地重用已建立的深度学习构建块。作者希望这种方法能够产生概念上熟悉的模型,这些模型需要较少的努力就能实现良好的性能。所有模型的附加讨论和技术细节都提供在补充材料中。
符号说明。在本工作中,我们考虑监督学习问题。数据集表示为 D = { ( x i , y i ) } i = 1 n D=\{(x_i,y_i)\}^n_{i=1} D={(xi,yi)}i=1n,其中 x i = ( x i ( n u m ) , x i ( c a t ) ) ∈ X x_i=(x^{(num)}_i , x^{(cat)}_i ) ∈ X xi=(xi(num),xi(cat))∈X代表数值 x i j ( n u m ) x^{(num)}_{ij} xij(num)和分类 x i j ( c a t ) x^{(cat)}_{ij} xij(cat) 特征的对象, y i ∈ Y y_i ∈ Y yi∈Y表示相应的对象标签。特征总数表示为k。数据集被分割为三个不相交的子集: D = D t r a i n ∪ D v a l ∪ D t e s t D = D_{train} ∪ D_{val} ∪ D_{test} D=Dtrain∪Dval∪Dtest,其中 D t r a i n D_{train} Dtrain用于训练, D v a l