TABR: TABULAR DEEP LEARNING MEETS NEAREST NEIGHBORS IN 2023
论文地址:https://arxiv.org/abs/2307.14338
源代码:https://github.com/yandex-research/tabular-dl-tabr
摘要
针对表格数据问题(例如分类、回归)的深度学习(DL)模型越来越受欢迎,而基于梯度提升决策树(GBDT)的非DL算法仍然是这些问题的一个强有力的解决方案。
利用检索增强模型是当前提高表格式深度学习的地位的方法之一。检索增强的方式:对于目标对象,这些模型从可用的训练数据中检索其他对象(例如最近的邻居),并使用它们的特征和标签来进行更好的预测。
本项工作提出了 TabR:它能够利用检索增强的方式提升表格的深度学习的性能,它在本质上是一个前馈网络,中间有一个自定义的k-最近邻组件。
- 性能很好,在“GBDT-friendly” benchmark等数据集上表现很好。
- 主要的发现和技术细节在于负责检索最近邻居并从中提取有价值信号的注意力机制。
图一:比较DL模型与XGBoost在中等规模(≤ 50K对象)的43个回归和分类任务上的差异。
该数据来自“Why do tree-based models still outperform deep learning on typical tabular data?” 。这篇文章做了在DL和树模型上的对比。可以看见TabR有了很大的突破。
一. 介绍
表格上的ML问题,原来是梯度提升决策树是首选,现在DL模型收到了很多关注
检索增强的表格DL模型
检索增强模型从训练集中检索其他对象
- G. Somepalli, M. Goldblum, A. Schwarzschild, C. B. Bruss, and T. Goldstein. SAINT: improved
neural networks for tabular data via row attention and contrastive pre-training. arXiv, 2106.01342v1,
- 1, 3, 5, 7, 25
- Qin, W. Zhang, R. Su, Z. Liu, W. Liu, R. Tang, X. He, and Y. Yu. Retrieval & interaction machine
for tabular data prediction. In KDD, 2021. 1, 2 - J. Kossen, N. Band, C. Lyle, A. N. Gomez, T. Rainforth, and Y. Gal. Self-attention between datapoints: Going beyond individual input-output pairs in deep learning. In NeurIPS, 2021. 1, 2, 3, 5, 6, 7, 15, 25
检索增强的好处:与纯参数(即无检索)模型相比,基于检索的模型可以实现更高的性能,并且还表现出几个实际上重要的特性,例如增量学习的能力和更好的鲁棒性(Das等人,2021; Jia等人 2021年)。
贡献
前人的论文他们提供,如果只有轻微的好处,适当调整多层感知器(MLP;最简单的参数模型),而显着更复杂和昂贵的。
贡献总结:
- 设计了TabR -一个简单的检索增强的表格DL模型,在数据集中表现好
- 具有中等规模任务的基准上优于GBDT
- 强调了重要的自由度的注意力机制(检索为基础的模型中经常使用的模块),允许设计更好的检索为基础的表格模型。
二、相关工作
1. Gradient boosted decision trees (GBDT)
- 梯度提升决策树(GBDT)基于GBDT的ML模型是表格数据监督问题的非DL解决方案
2. Parametric deep learning models.
旨在将深度学习的好处带到表格数据世界,同时实现有竞争力的性能
参数化表格DL
MLP样主链仍然具有竞争性,以及用于连续特征的嵌入显著地减小了表格式DL和GBDT之间的差距。在这项工作中,我们表明,一个适当设计的检索组件可以进一步提高性能的表格DL。
3. 一般的检索增强生成
一般的检索增强模型。通常,基于检索的模型设计如下。
- 对于输入对象,首先,它们从可用(训练)数据中检索相关样本。然后,它们将输入对象与检索到的实例一起处理,以生成输入对象的最终预测。
- 设计基于检索的方案的常见动机之一是局部学习范例(Bottou和Vapnik,1992),并且这种模型的最简单的可能示例是k-最近邻(kNN)算法(James等人,2013年)。
- 基于检索的方法的前景在各个领域得到了证明,例如自然语言处理(刘易斯et al.,2020; Guu等人,2020; Khandelwal等人,2020; Izacard等人,2022年; Borgeaud等人,2022)、计算机视觉(Iscen等人,2022; Long等人,2022)、CTR预测(Qin等人,2020; 2021; Du等人,#20222;及其他。
- 此外,检索增强模型通常具有有用的属性,例如更好的可解释性(Wang和Sabuncu,2023),鲁棒性(Zhao和Cho,2018)等。
4. 表格数据的检索增强生成
表格数据问题的检索增强模型。
- 基于非深度检索的表格模型的经典示例是基于邻居的方法和核方法(James等人,2013; Nader等人,2022年)
- Y. Nader, L. Sixt, and T. Landgraf. Dnnr: Differential nearest neighbors regression. In ICML, 2022. 2, 5, 7
- 还有适用于(或直接设计用于)表格数据问题的基于深度检索的模型
- (Ramsauer等人,2021;科森等人,2021; Somepalli等人,2021)和下面的相同
它们中的一些省略了检索步骤,并使用所有训练数据点作为“检索”实例
现有的缺点:然而,我们表明,现有的基于检索的表格DL模型只是略优于简单的参数DL模型,并且通常使用重型变压器式架构的成本。
本文的优势:在先前的工作中,经常使用在对象和特征之间具有多头注意力的几个层(Ramsauer等人,2021;科森等人,2021; Somepalli等人,2021)
- Ramsauer, B. Schäfl, J. Lehner, P. Seidl, M. Widrich, L. Gruber, M. Holzleitner, T. Adler, D. P.
Kreil, M. K. Kopp, G. Klambauer, J. Brandstetter, and S. Hochreiter. Hopfield networks is all you need. In ICLR, 2021. 2, 3 - Kossen, N. Band, C. Lyle, A. N. Gomez, T. Rainforth, and Y. Gal. Self-attention between datapoints: Going beyond individual input-output pairs in deep learning. In NeurIPS, 2021. 1, 2, 3, 5, 6, 7, 15, 25
- G. Somepalli, M. Goldblum, A. Schwarzschild, C. B. Bruss, and T. Goldstein. SAINT: improved
neural networks for tabular data via row attention and contrastive pre-training. arXiv, 2106.01342v1, 2021. 1, 3, 5, 7, 25
我们的模型TabR只使用一个类似注意力的单头模块来实现其检索组件。重要的是,TabR的单个注意力模块是定制的,使其更适合表格数据问题。因此,TabR大大优于现有的基于检索的DL模型,同时效率更高。
三. TABR
1. 准备
1) 标注
xi表示第i个对象的特征,yi表示第i个对象的标签
考虑三种类型的任务:二元分类Y = {0,1},多类分类Y = {1,…C}和回归Y = R
默认特征是连续变量,对于二元or多元的特征额外标注
数据集被分割为三部分
- 训练集:用于训练
- 验证集Validation:用于早停止和参数微调
- 测试集:用于最终的评估
检索:在“上下文候选者”或简称为“候选者”的集合内执行(在train数据集内)
- 检索到的对象又被称为“上下文对象”或简称为“上下文”。
- 在这项工作中,我们对所有输入对象使用相同的候选集。
2) 实验设置
我们在D.6小节中详细描述了我们的调优和评估协议。最重要的一点是,对于任何给定的算法,在每个数据集上,遵循Gorishniy等人(2022)
- Y. Gorishniy, I. Rubachev, and A. Babenko. On embeddings for numerical features in tabular deep learning. In NeurIPS, 2022. 1, 2, 3, 4, 6, 14, 19, 20, 21, 22, 27, 28——CCF A
- (1)我们使用验证集执行超参数调整和早期停止;
- (2)对于最佳超参数,在正文中,我们报告了测试集上平均15个随机种子的度量,并在附录E中提供了标准差;
- (3)当比较任何两种算法时,我们考虑标准偏差,如D.6小节所述;
- (4)为了获得相同类型模型的集合,我们将15个随机种子分成三个不相交的组(即,分成三个集合),每个集合由五个模型组成,对每组内的预测进行平均,并报告所获得的三个集合的平均性能。
数据集:我们主要使用以前文献中的数据集,并在表1中提供其摘要
2. 架构
为了建立一个基于检索的表格DL模型,我们选择了一个增量的方法,我们从一个简单的没有检索的架构,并逐步增加和改进的检索组件。
首先,用同一编码器E对目标对象及其检索候选进行编码。
然后,检索模块R通过从候选者中检索和处理相关对象来丰富目标对象的表示。
最后,预测器P做出预测。粗体路径突出显示了在添加检索模块R之前前馈无检索模型的结构
- 将候选表的标签也输入到检索模块中
1)编码器和预测器
不是本文的重点,因此只用保持简单就可以
输入模块包块输入处理路线:
2)检索模块
k-最近邻检索
1.如果编码器E包含至少一个块(即NE > 0),则在被传递到R之前,利用共享层归一化来对Nmax和所有Nmax Xi进行归一化。
2.可选地,目标对象本身可以无条件地(即,忽略前m个操作)作为第(m+ 1)个对象添加到具有相似性得分S(m x,m x)的其上下文对象集合。——即可以将目标对象加入到上下文对象集合中。
3.Dropout应用于softmax函数产生的权重。4.在整个论文中,我们使用m = 96和Icand = Itrain(候选对象等于训练集)。
现在,我们讨论相似性模块S和值模块V(在图4中介绍)的可能设计。
在此过程中,我们不使用数字特征的嵌入(等人,2022),并设置NE = 0,NP = 1(见图3)。
- Y. Gorishniy, I. Rubachev, and A. Babenko. On embeddings for numerical features in tabular deep learning. In NeurIPS, 2022. 1, 2, 3, 4, 6, 14, 19, 20, 21, 22, 27, 28
图4:图2中引入的检索模块R的简化图示
对于目标对象的表示法n x,该模块根据相似性模块S:(Rd,Rd)→ R在候选对象{n Xi}中选取m个最近邻,并聚合它们由值模块V:(Rd,Rd,Y)→ Rd产生的值。
对于目标对象的表示法x~,该模块根据相似性模块S:(Rd,Rd)→ R在候选对象{Xi~}中选取m个最近邻,并聚合它们由值模块V: (Rd, Rd, Y) → Rd. (即最后将结果转换为d维向量)
3)实现步骤
步骤0:vanilla-attention-like baseline:
自我注意操作(Vaswani等人,2017)在先前的工作中经常用于对目标对象和候选/上下文对象之间的交互进行建模
然后,将检索模块R实例化为vanilla self-attention(对top-m操作取模)是合理的基线:其中WQ、WK和WV是线性层,并且目标对象作为第(m+ 1)个对象添加到其自己的上下文(即,忽略top-m操作)。
步骤-1:添加上下文标签。
改进Step-0配置的自然尝试是利用上下文对象的标签,例如,通过将它们合并到值模块中,如下所示:
表2显示使用标签没有任何改进,这是违反直觉的。
步骤2: 改进了相似度模块S
根据经验,我们观察到删除查询的概念(即删除WQ)并使用L2距离而不是点积显着提高了表2中几个数据集的性能:
在A.3小节中,我们证明了删除三个成分中的任何一个(上下文标签,仅键表示,L2距离)都会导致性能下降到MLP的水平。虽然L2距离不太可能是问题的普遍最佳选择(即使在表格域中),但它似乎是表格数据问题的合理默认选择
步骤3:改进价值模块V。
现在,我们从DNNR(Nader等人,2022)-最近提出的kNN算法的回归问题的推广。也就是说,我们通过考虑目标对象的表示法来使值模块V更具表现力
其中与等式3的差异被加下划线。表2显示,新的值模块进一步提高了几个数据集的性能。直观地说,项WY(yi)(上下文对象的标签的嵌入)可以被视为第i个上下文对象的“原始”贡献。项T(WK(x)-WK(Xi))可以被视为“校正”项,其中模块T将密钥空间中的差异转换为标签嵌入空间中的差异。
步骤4:TabR
根据经验,我们观察到,在相似度模块中省略缩放项d−1/2,并且不将目标对象包含在其自己的上下文中,平均而言会产生更好的结果,
其中WK是线性层,WY是用于分类任务的嵌入表和用于回归任务的线性层,(默认情况下)目标对象不包括在其自身的上下文中,(默认情况下)相似性得分不缩放
局限性。TabR具有检索增强模型的标准限制,我们在附录B中对此进行了描述。我们鼓励从业者在实践中使用TabR之前审查限制。
实验
我们将TabR(在第3节中介绍)与现有的基于检索的解决方案和最先进的参数模型进行比较。
除了TabR的完全成熟的配置(如图3所述,E和P具有所有可用的自由度)之外,我们还使用TabR-S(“S”代表“简单”)-一种简单的配置,它不使用特征嵌入(Gorishniy等人,2022)具有线性编码器(NE = 0)和一个块预测器(NP = 1)。我们指定何时TabR-S仅用于表格、图形和标题,而不用于文本。有关TabR的其他详细信息,包括超参数调优,请参见D.8小节。
1. 评估用于表格数据的检索增强深度学习模型
在本节中,我们将TabR(第3节)和现有的检索增强解决方案与完全参数化DL模型进行比较(所有算法的实现细节请参见附录D)。
所获得的结果突出了检索技术和数字特征的嵌入(Gorishniy等人,2022)(用于MLP-PLR和TabR)作为两个强大的架构元素,提高了表格DL模型的优化属性。有趣的是,这两种技术并不是完全正交的,但它们都不能恢复另一种技术的全部能力,这取决于给定的数据集,是否应该选择检索,嵌入或两者的组合。
这两种技术并不是完全正交的,但它们都不能恢复另一种技术的全部能力,这取决于给定的数据集,是否应该选择检索,嵌入或两者的组合。
2. 比较TabR和梯度增强决策树
在本节中,我们将TabR与基于梯度提升决策树(GBDT)的模型进行了比较:
- XGBoost(Chen和Guestrin,2016),LightGBM(Ke等人,2017)和CatBoost(Prokhorenkova等人,2018年)。
具体来说,我们比较集成(例如TabR的集成与XGBoosts的集成)以进行公平的比较,因为梯度提升已经是一种集成技术。
五. 分析
1. 冻结内容以加快TABR训练
在TabR的公式中(第3节),对于每个训练批次,最新的上下文是通过编码所有候选项并计算与所有候选项的相似性来挖掘的,这在大型数据集上可能非常慢。
然而,如图5所示,对于一个平均的训练对象,它的上下文(即前m个候选者和根据相似性模块S在它们上的分布)在训练过程中逐渐“稳定”,这为简单的优化提供了机会。
也就是说,**在固定数量的时期之后,我们可以执行“上下文冻结”:即,最后一次计算所有训练(但不是验证和测试)对象的最新上下文,然后在其余的训练中重用这些上下文。**表6表明,在某些数据集上,这种简单的技术可以加速TabR的训练,而不会损失太多指标,并且在较大的数据集上有更明显的加速。特别是,在完整的“天气预测”数据集上,我们实现了近七倍的加速(从18 h9 min到3 h15 min),同时保持了竞争力的RMSE。实施细节见D.2小节。
2. 使用新训练数据验证TABR,无需重新训练
在训练机器学习模型之后访问新的看不见的训练数据(例如,在收集应用程序的日常日志的又一部分之后)是常见的实际场景。
从技术上讲,TabR允许使用新数据,而无需通过将新数据添加到候选检索集合中进行重新训练。
我们在完整的“天气预测”数据集上测试了这种方法(Malinin等人,2021年)(3M+对象)。图6表明,这种“在线更新”可能是将新数据合并到已经训练的TabR中的可行解决方案。此外,这种方法可以通过在数据子集上训练模型并从完整数据中检索来将TabR扩展到大型数据集。总的来说,我们认为所进行的实验是一个初步的探索,并留下一个系统的研究,为未来的工作不断更新。
实施细节见D.3小节。
3. 深入分析
在附录中,我们提供了更有见地的分析。
- 在A.1小节中,我们分析了在3.2小节的步骤2中引入的基于L2的仅键相似度模块S,这是我们故事的转折点。我们提供了S的这种具体实现背后的直觉,并与vanilla attention(查询和键之间的点积)的相似性模块进行了深入的比较。
- 在A.2小节中,我们分析在3.2小节的步骤3中引入的价值模块V。在回归问题上,我们从方程4中确认了模块T的校正语义。
- 在A.4小节中,我们将TabR的训练时间与所有基线的训练时间进行比较。我们表明,与以前的基于检索的表格模型相比,TabR在效率方面向前迈出了一大步。虽然TabR比简单的无检索模型相对较慢,但在考虑的数据集大小范围内,TabR的绝对训练时间对于大多数实际场景来说都是可以承受的。
- 在A.7小节中,我们强调了TabR的其他技术属性。
六. 总结和未来工作
我们强调了相似性和价值模块的重要细节的注意力机制有显着的影响,基于注意力的检索组件的性能。
- 未来工作的一个重要方向是提高检索增强模型的效率,使其更快,特别是适用于数千万和数亿个数据点。
此外,在本文中,我们更多地关注任务性能方面,因此TabR的其他一些属性仍然没有得到充分的研究。例如,TabR的检索性质为通过上下文对象的影响来解释模型的预测提供了新的机会。
- 此外,TabR可以更好地支持持续学习(我们在5.2小节中触及了这个方向的表面)。
- 关于架构细节,可能的方向是改进相似性和价值模块,以及执行多轮检索和与检索到的实例的交互
附录
附录B:局限性和实际考虑
以下限制和实际考虑一般适用于检索增强模型。TabR本身并没有向这个列表添加任何新内容。
- 首先,对于给定的应用程序,应该从各个角度(业务逻辑、法律的考虑、道德方面等)仔细评估。使用真实的训练对象进行预测是否合理。
- 其次,根据应用程序,对于给定的目标对象,可能希望仅从可用数据的子集中检索,其中该子集是基于特定于应用程序的过滤器为目标对象动态形成的。根据3.1小节,它意味着Icand = Icand(x)<$Itrain。
- 第三,理想情况下,训练期间的检索应该模拟部署期间的检索,否则,基于检索的模型可能导致(高度)次优性能。
- 对于时间序列,在训练期间,必须允许TabR仅从过去检索。此外,也许,这种“过去”也应该受到限制,以防止从太旧的数据和太新的数据中检索。应该根据领域专业知识和业务逻辑做出决策。
- 让我们考虑一个任务,其中所有训练对象中有一些“相关对象”。例如,当将排名问题作为逐点回归来解决时,可以获得这样的“相关对象”作为对应于相同查询但不同文档的查询-文档对。在某些情况下,在训练期间,对于给定的目标对象,从“相关对象”检索可能是不公平的,因为对于在可用数据中没有“相关对象”的新对象,在生产中不可能这样做。同样,应该根据领域专业知识和业务逻辑做出设计决策。
- 最后,虽然TabR比现有的基于检索的表格DL模型更有效,但与纯参数模型相比,检索模块R仍然会导致开销,因此TabR可能无法按原样扩展到真正的大型数据集。我们在5.1小节中展示了一个将TabR扩展到更大数据集的简单技巧。我们将在A.4小节中更详细地讨论效率方面。