Revisiting Deep Learning Models for Tabular Data
arxiv [Submitted on 22 Jun 2021 , last revised 26 Oct 2023 ]
代码:https://github.com/Yura52/tabular-dl-revisiting-models
https://github.com/yandex-research/tabular-dl-revisiting-models
摘要
本文回顾了为表格数据设计的各种深度学习模型,并强调了这些模型之间缺乏适当的比较,这使得很难确定性能最佳的模型。它确定了两种简单而强大的深度学习架构:类似 ResNet 的模型和经过调整的 Transformer 模型,这两个架构在不同任务中都表现出强劲的性能。作者使用相同的训练和调整协议将这些模型与许多现有架构进行了比较,以确保进行公平比较。他们还将最佳深度学习模型与梯度增强决策树进行了比较,并得出结论,没有适用于所有任务的普遍优越解决方案。他们的实验源代码已在 GitHub 上提供,允许其他研究人员在他们的工作基础上进行复制和构建。
1 Introduction(引言)
由于深度学习在图像、音频和文本等数据领域取得了巨大成功,研究者们对将这些成功扩展到表格数据问题上表现出了极大的兴趣。在这些问题中,数据点被表示为异构特征向量的集合,这在工业应用和机器学习竞赛中很常见。神经网络在这些领域有一个强大的非深度竞争对手,即梯度提升决策树(GBDT)。除了可能更高的性能外,使用深度学习处理表格数据的吸引力在于,它允许构建多模态流水线,其中输入的一部分是表格数据,其他部分包括图像、音频和其他对深度学习友好的数据。这样的流水线可以针对所有模态通过梯度优化进行端到端的训练。由于这些原因,最近提出了许多深度学习解决方案,并且新的模型不断涌现。
不幸的是,由于缺乏像计算机视觉领域的ImageNet或自然语言处理领域的GLUE这样的建立基准,现有的论文使用不同的数据集进行评估,提出的深度学习模型通常没有得到充分的相互比较。因此,从当前文献中不清楚哪个深度学习模型通常表现更好,以及GBDT是否被深度学习模型超越。此外,尽管提出了大量新架构,该领域仍然缺乏简单可靠的解决方案,这些解决方案可以在适度的努力下实现竞争性能,并在许多任务中提供稳定的表现。在这方面,多层感知器(MLP)仍然是该领域主要的简单基线,然而,它并不总是对其他竞争者构成重大挑战。
描述的问题妨碍了研究进程,使论文的观察结果不够有决定性。因此,作者认为,审查该领域的最新发展并提高表格深度学习的基线标准是及时的。作者从假设开始,即在表格数据的背景下,可能还没有充分探索在其他领域经过充分研究的深度学习架构块,这些块可以用来设计更好的基线。因此,作者从其他领域的知名且经过实战考验的架构中获得灵感,并为表格数据获得了两个简单的模型。第一个是一个类似ResNet的架构(He et al., 2015b),第二个是FT-Transformer——作者对Transformer架构的简单适配(Vaswani et al., 2017),用于表格数据。然后,作者在相同的训练和超参数调整协议下,将这些模型与许多现有解决方案进行了多样化任务的比较。首先,作者发现没有任何考虑的深度学习模型能够一致性地超越类似ResNet的模型。鉴于其简单性,它可以作为未来工作的强基线。其次,FT-Transformer在大多数任务上表现出最佳性能,成为该领域的一个新的强大解决方案。有趣的是,FT-Transformer被证明是表格数据的更通用架构:它在比更“传统”的ResNet和其他深度学习模型更广泛的任务范围内表现良好。最后,作者将最好的深度学习模型与GBDT进行了比较,并得出结论,仍然没有普遍优越的解决方案。
我们总结本文的贡献如下:
- 作者在一个多样化的任务集合上彻底评估了表格深度学习的主要模型,以调查它们的相对性能。
- 作者证明了一个简单的类似ResNet的架构是表格深度学习的有效基线,在现有文献中被忽略了。鉴于其简单性,我们推荐这个基线用于未来的表格深度学习工作。
- 作者介绍了FT-Transformer——一个针对表格数据的Transformer架构的简单适配,成为该领域的一个新的强大解决方案。作者观察到它是一个更通用的架构:它在比其他深度学习模型更广泛的任务范围内表现良好。
- 作者揭示了在GBDT和深度模型中仍然没有普遍优越的解决方案。
2 Related work(相关工作)
在表格数据问题上,“浅层”的最新技术目前是决策树集成,例如梯度提升决策树(GBDT)。GBDT通常是各种机器学习竞赛中的首要选择。目前,有几个GBDT库被广泛使用,包括XGBoost、LightGBM和CatBoost。尽管这些实现在细节上有所不同,但在大多数任务上,它们的性能并没有太大差异。 近年来,为了表格数据开发了大量的深度学习模型。这些模型大致可以分为以下三组:
可微分树:第一组模型受到决策树集成在表格数据上强大性能的启发。由于决策树不是可微分的,不允许梯度优化,它们不能作为以端到端方式训练的流水线的组件。为了解决这个问题,一些工作提出了在内部树节点“平滑”决策函数,使整体树函数和树路由可微分。
基于注意力的模型:由于注意力基础架构在不同领域(如图像和自然语言处理)的普遍成功,一些作者也提出在表格数据深度学习中使用类似注意力的模块。
显式建模乘法交互:在推荐系统和点击率预测的文献中,一些工作批评了多层感知器(MLP)不适合建模特征之间的乘法交互。受到这种动机的启发,一些工作提出了不同的方法将特征乘积融入MLP中。
此外,文献还提出了一些其他架构设计,这些设计不能明确归入上述任何一组。总体而言,社区开发了各种在不同基准上评估且很少相互比较的模型。作者的工作旨在建立一个公平的比较,并识别那些一致提供高性能的解决方案。
3 Models for tabular data problems(表格数据问题的模型)
在本节中,作者描述了工作中突出的主要深度学习架构,以及在比较中包含的现有解决方案。由于作者认为该领域需要强大的易于使用的基线,他们在设计ResNet(第3.2节)和FT-Transformer(第3.3节)时尽可能多地重用已建立的深度学习构建块。作者希望这种方法能够产生概念上熟悉的模型,这些模型需要较少的努力就能实现良好的性能。所有模型的附加讨论和技术细节都提供在补充材料中。
符号说明。在本工作中,我们考虑监督学习问题。数据集表示为 D = { ( x i , y i ) } i = 1 n D=\{(x_i,y_i)\}^n_{i=1} D={(xi,yi)}i=1n,其中 x i = ( x i ( n u m ) , x i ( c a t ) ) ∈ X x_i=(x^{(num)}_i , x^{(cat)}_i ) ∈ X xi=(xi(num),xi(cat))∈X代表数值 x i j ( n u m ) x^{(num)}_{ij} xij(num)和分类 x i j ( c a t ) x^{(cat)}_{ij} xij(cat) 特征的对象, y i ∈ Y y_i ∈ Y yi∈Y表示相应的对象标签。特征总数表示为k。数据集被分割为三个不相交的子集: D = D t r a i n ∪ D v a l ∪ D t e s t D = D_{train} ∪ D_{val} ∪ D_{test} D=Dtrain∪Dval∪Dtest,其中 D t r a i n D_{train} Dtrain用于训练, D v a l D_{val} Dval用于早期停止和超参数调整, D t e s t D_{test} Dtest用于最终评估。我们考虑三种类型的任务:二元分类 Y = { 0 , 1 } Y = \{0, 1\} Y={0,1},多类分类 Y = { 1 , . . . , C } Y = \{1, . . . , C\} Y={1,...,C}和回归 Y = R Y = R Y=R。
3.1 MLP
我们在公式1中形式化了“MLP”架构。
M L P ( x ) = L i n e a r ( M L P B l o c k ( . . . ( M L P B l o c k ( x ) ) ) ) MLP(x) = Linear (MLPBlock (. . . (MLPBlock(x)))) MLP(x)=Linear(MLPBlock(...(MLPBlock(x))))
M L P B l o c k ( x ) = D r o p o u t ( R e L U ( L i n e a r ( x ) ) ) MLPBlock(x) = Dropout(ReLU(Linear(x))) MLPBlock(x)=Dropout(ReLU(Linear(x)))
3.2 ResNet
作者知道有一次尝试设计一个类似ResNet的基线,但是报告的结果并不具有竞争力。然而,鉴于ResNet在计算机视觉和最近在自然语言处理任务中的成功,他们构建了一个简单的ResNet变种,如公式2所述。主要的构建块与原始架构相比被简化,作者发现这对于优化是有益的。
R e s N e t ( x ) = P r e d i c t i o n ( R e s N e t B l o c k ( . . . ( R e s N e t B l o c k ( L i n e a r ( x ) ) ) ) ) ResNet(x) = Prediction (ResNetBlock (. . . (ResNetBlock (Linear(x))))) ResNet(x)=Prediction(ResNetBlock(...(ResNetBlock(Linear(x)))))
R e s N e t B l o c k ( x ) = x + D r o p o u t ( L i n e a r ( D r o p o u t ( R e L U ( L i n e a r ( B a t c h N o r m ( x ) ) ) ) ) ) ResNetBlock(x) = x + Dropout(Linear(Dropout(ReLU(Linear(BatchNorm(x)))))) ResNetBlock(x)=x+Dropout(Linear(Dropout(ReLU(Linear(BatchNorm(x))))))
P r e d i c t i o n ( x ) = L i n e a r ( R e L U ( B a t c h N o r m ( x ) ) ) Prediction(x) = Linear (ReLU (BatchNorm (x))) Prediction(x)=Linear(ReLU(BatchNorm(x)))
3.3 FT-Transformer
在本节中,作者介绍了FT-Transformer(Feature Tokenizer + Transformer)——一个简单的Transformer架构适配,用于表格数据领域。图1展示了FT-Transformer的主要内容。作者的模型将所有特征(分类和数值)转换为嵌入,并在嵌入上应用一系列Transformer层。每个Transformer层操作一个对象的特征级别的表示。
图1:FT-Transformer架构。首先,Feature Tokenizer将特征转换为嵌入。然后,嵌入通过Transformer模块处理,并使用[CLS]标记的最终表示进行预测。
图2:(a)Feature Tokenizer;在这个例子中,有三个数值和两个分类特征;(b)一个Transformer层。
Feature Tokenizer。Feature Tokenizer模块(见图2)将输入特征x转换为嵌入 T ∈ R k × d T ∈ R^{k×d} T∈Rk×d。给定特征xj的嵌入Tj计算如下: T j = b j + f j ( x j ) ∈ R d T_j = b_j + f_j(x_j) ∈ R^d Tj=bj+fj(xj)∈Rd f j : X j → R d f_j : X_j → R^d fj:Xj→Rd。
其中 b j b_j bj是第j个特征的偏置, f j ( n u m ) f^{(num)}_j fj(num)作为元素乘法实现为向量 W j ( n u m ) ∈ R d W^{(num)}_j ∈ R^d Wj(num)∈Rd, f j ( c a t ) f^{(cat)}_j fj(cat)作为查找表 W j ( c a t ) ∈ R S j × d W^{(cat)}_j ∈ R^{S_j×d} Wj(cat)∈RSj×d实现为分类特征。总的来说:
T j ( n u m ) = b j ( n u m ) + x j ( n u m ) ⋅ W j ( n u m ) ∈ R d T^{(num)}_j = b^{(num)}_j + x^{(num)}_j · W^{(num)}_j ∈ R^d Tj(num)=bj(num)+xj(num)⋅Wj(num)∈Rd,
T j ( c a t ) = b j ( c a t ) + e j T W j ( c a t ) ∈ R d T^{(cat)}_j = b^{(cat)}_j + e^T_j W^{(cat)}_j ∈ R^d Tj(cat)=bj(cat)+ejTWj(cat)∈Rd,
T = s t a c k ( T 1 ( n u m ) , . . . , T k ( n u m ) ( n u m ) , T 1 ( c a t ) , . . . , T k ( c a t ) ( c a t ) ) ∈ R k × d T = stack(T^{(num)}_1 , . . . , T^{(num)}_{k^{(num)}}, T^{(cat)}_1 , . . . , T^{(cat)}_{k^{(cat)}}) ∈ R^{k×d} T=stack(T1(num),...,Tk(num)(num),T1(cat),...,Tk(cat)(cat))∈Rk×d。
其中 e j T e^T_j ejT是相应分类特征的独热向量。
Transformer。在这个阶段,将[CLS]标记的嵌入(或“分类标记”或“输出标记”)附加到T上,然后应用L层Transformer F 1 , . . . , F L F_1, . . . , F_L F1,...,FL:
T 0 = s t a c k [ [ C L S ] , T ] T_0 = stack[[CLS], T] T0=stack[[CLS],T]
T i = F i ( T i − 1 ) T_i = F_i(T_{i−1}) Ti=Fi(Ti−1)。
我们使用PreNorm变体以便于优化,见图2。在PreNorm设置中,我们发现有必要从第一个Transformer层中移除第一个标准化以实现良好的性能。有关激活、标准化的位置和dropout模块的详细信息,请参见补充材料。 预测。使用[CLS]标记的最终表示进行预测:
y ^ = L i n e a r ( R e L U ( L a y e r N o r m ( T L [ C L S ] ) ) ) \hat y= Linear(ReLU(LayerNorm(T^{[CLS]}_L))) y^=Linear(ReLU(LayerNorm(TL[CLS])))。
限制。FT-Transformer需要比简单模型如ResNet更多的资源(硬件和时间)进行训练,并且可能不容易扩展到特征数量非常大的数据集。因此,FT-Transformer在解决表格数据问题上的广泛使用可能导致由机器学习流水线产生的CO2排放量增加。这个问题的主要原因在于原始MHSA对于特征数量具有二次复杂度。然而,这个问题可以通过使用MHSA的有效近似来缓解。此外,仍然可以将FT-Transformer蒸馏成更简单的架构以提高推理性能。
3.4 其他模型
在这一部分,我们列出了专门为表格数据设计的现有模型,并将它们包含在比较中。
• SNN。一个类似于MLP的架构,使用SELU激活,可以训练更深层次的模型。
• NODE。一个可微分的无记忆决策树集成。
• TabNet。一个循环架构,交替进行特征的动态重加权和传统的前馈模块。
• GrowNet。梯度提升的弱MLP。官方实现仅支持分类和回归问题。
• DCN V2。由一个类似于MLP的模块和特征交叉模块(线性层和乘法的组合)组成。
• AutoInt。将特征转换为嵌入,并应用一系列基于注意力的转换到嵌入上。
• XGBoost。最受欢迎的GBDT实现之一。
• CatBoost。一个GBDT实现,使用无记忆决策树作为弱学习器。
4 Experiments
在本节中,作者将深度学习模型(DL)相互比较,并与梯度提升决策树(GBDT)进行比较。注意,在正文中,作者只报告了关键结果。在补充材料中,作者提供了:(1)所有模型在所有数据集上的结果;(2)硬件信息;(3)ResNet和FT-Transformer的训练时间。
4.1 比较范围
在本工作中,作者专注于不同架构的相对性能,并不采用各种模型无关的深度学习实践,例如预训练、额外的损失函数、数据增强、蒸馏、学习率预热、学习率衰减等。虽然这些实践可能会提高性能,但作者的目标是评估由不同模型架构施加的归纳偏差的影响。
4.2 数据集
作者使用了一组包含十一个公共数据集的多样化数据集。对于每个数据集,都有一个完全相同的训练-验证-测试分割,因此所有算法都使用相同的分割。数据集包括:加利福尼亚住房(CA,房地产数据)、成人(AD,收入估计)、海伦娜(HE,匿名数据集)、詹尼斯(JA,匿名数据集)、希格斯(HI,模拟物理粒子)、ALOI(AL,图像)、Epsilon(EP,模拟物理实验)、年份(YE,音频特征)、覆盖类型(CO,森林特征)、雅虎(YA,搜索查询)、微软(MI,搜索查询)。我们按照逐点学习方法处理学习排序问题,并将排名问题(微软、雅虎)视为回归问题。数据集属性总结在表1中。
4.3 实现细节
数据预处理。对于每个数据集,所有深度模型都使用相同的预处理步骤,以确保公平比较。默认情况下,作者使用了Scikit-learn库中的分位数变换。对于海伦娜和ALOI数据集,作者应用了标准化(均值减去和缩放),因为这是计算机视觉中常见的做法。在Epsilon数据集上,作者观察到预处理对深度模型的性能有负面影响,因此在该数据集上使用原始特征。对于所有算法,作者对回归目标进行了标准化处理。
超参数调整。对于每个数据集,作者为每个模型精心调整超参数。在验证集上表现最佳的超参数被选为最终参数,因此测试集从未用于调整。对于大多数算法,作者使用Optuna库运行贝叶斯优化,这被认为优于随机搜索。对于其他算法,作者在相应论文推荐的预定义配置集上进行迭代。作者提供了参数空间和网格,并在补充材料中提供了关于如何根据时间设置预算的额外分析。
评估。对于每个调整后的配置,作者使用不同的随机种子运行15次实验,并报告测试集上的性能。对于一些算法,作者还报告了没有进行超参数调整的默认配置的性能。
集成。对于每个模型,在每个数据集上,作者通过将15个单一模型分成三个不相交的等大小组,并计算每个组内单一模型预测的平均值来获得三个集成。
神经网络。作者对分类问题使用交叉熵损失进行优化,对回归问题使用均方误差损失进行优化。对于TabNet和GrowNet,作者遵循原始实现并使用Adam优化器。对于所有其他算法,作者使用AdamW优化器。作者没有应用学习率调度。对于每个数据集,作者为所有算法使用预定义的批量大小,除非在相应论文中给出了关于批量大小的特别说明。
分类特征。对于XGBoost,作者使用独热编码。对于CatBoost,作者利用其内置的对分类特征的支持。对于神经网络,作者为所有分类特征使用相同维度的嵌入。
4.4 比较深度学习模型
表2报告了深度学习架构的结果。主要的发现包括:
- 多层感知器(MLP)仍然是一个良好的基准。
- ResNet被证明是一个有效的基线,没有一个竞争对手能够一致性地超越它。
- FT-Transformer在大多数任务上表现最佳,成为该领域的一个新的强大解决方案。
- 适当的调整可以使简单的模型如MLP和ResNet具有竞争力,因此我们建议在可能的情况下对基线进行调整。幸运的是,现在有了像Optuna这样的库,使得调整变得更加容易。
在其他模型中,NODE(Popov et al., 2020)是唯一在几个任务上表现出高性能的模型。然而,它在六个数据集(Helena、Jannis、Higgs、ALOI、Epsilon、Covertype)上仍然不如ResNet,同时它是一个更复杂的解决方案。此外,它不是一个真正的“单一”模型;实际上,它通常包含的参数数量比ResNet和FT-Transformer显著更多,并且具有集成式结构。我们在表3中通过比较集成来说明这一点。结果表明,FT-Transformer和ResNet从集成中受益更多;在这种情况下,FT-Transformer超越了NODE,而ResNet和NODE之间的差距显著缩小。尽管如此,NODE仍然是基于树的方法中的一个突出解决方案。
4.5 将深度学习模型与GBDT进行比较
在本节中,我们的目标是检查深度学习模型是否在概念上准备好超越GBDT。为此,作者比较了使用GBDT或深度学习模型可以实现的最佳可能指标值,不考虑速度和硬件要求(毫无疑问,GBDT是更轻量级的解决方案)。作者通过比较集成而不是单一模型来实现这一点,因为GBDT本质上是一种集成技术,作者期望深度架构从集成中受益更多(Fort et al., 2020)。在表4中报告了结果。
4.6 FT-Transformer的一个有趣的属性
表4还揭示了一个重要的结论。即,FT-Transformer在超越“传统”的深度学习模型ResNet方面提供了大部分优势,正是在GBDT优于ResNet的这些问题上(加利福尼亚住房、成年人、Covertype、雅虎、微软),而在其余问题上与ResNet表现相当。换句话说,FT-Transformer在所有任务上都提供了竞争性能,而GBDT和ResNet只在任务的某些子集上表现良好。这一观察可能是FT-Transformer是表格数据问题的更“通用”模型的证据。
5 Analysis(分析)
5.1 When FT-Transformer is better than ResNet?
在5.1节中,作者探讨了FT-Transformer在哪些情况下比ResNet表现得更好。通过设计一系列的合成任务,他们控制了数据的特征并逐渐改变了目标函数,从而观察两种模型的性能差异。实验结果显示,当目标函数更倾向于GBDT(梯度提升决策树)的风格时,ResNet的性能显著下降,而FT-Transformer则在所有任务上都保持了竞争性能。这表明FT-Transformer可能比ResNet更能泛化到不同类型的数据和任务。
5.2 Ablation study
5.2节进行了消融研究,以测试FT-Transformer的不同设计选择。作者首先将FT-Transformer与AutoInt进行了比较,AutoInt也是一种基于注意力机制的模型,但两者在嵌入层和主干结构上有所不同。消融研究还包括了去除FT-Transformer中的特征偏置,以评估其对性能的影响。结果显示,FT-Transformer在没有特征偏置的情况下性能下降,证明了特征偏置对于模型性能的重要性。
5.3 Obtaining feature importances from attention maps
在5.3节中,作者探讨了如何从FT-Transformer的注意力图中获取特征重要性。他们提出了一种基于注意力图平均值的方法来评估特征的重要性,并将其与集成梯度(IG)方法进行了比较。结果表明,注意力图的平均值可以作为一种高效的方式来估计特征重要性,并且与IG方法的结果具有相似的排名相关性。这表明,尽管IG是一种更为通用的方法,但在效率和成本效益方面,注意力图的平均值是一个合理的替代方案。
这些分析有助于更深入地理解FT-Transformer的工作原理,以及它在特定情况下为何能够超越其他模型。通过这些实验,作者提供了关于FT-Transformer性能和特征重要性估计的见解。
6 Conclusion(结论)
在这项工作中,作者调查了表格数据深度学习的当前状态,并提高了表格深度学习基线的标准。首先,作者展示了一个简单的类似ResNet的架构可以作为一个有效的基线。其次,作者提出了FT-Transformer——一个简单的Transformer架构适配,它在大多数任务上的性能超过了其他深度学习解决方案。作者还比较了新的基线与GBDT,并证明了GBDT在某些任务上仍然占主导地位。代码和所有研究的细节都是开源的,作者希望他们的评估和两个简单的模型(ResNet和FT-Transformer)将作为表格深度学习进一步发展的基石。
最后感谢你看到这里,以上观点均为本人对原论文的个人理解,仅作个人学习使用,如有错误或侵权,麻烦联系我,本人定修改或删除。
祝你天天开心,多笑笑。