NN如何在表格数据中战胜GBDT类模型

TabNet: Attentive Interpretable Tabular Learning

v2-071e132c7353c2be34deb925be1bf617_b.jpg
作者:一元,公众号:炼丹笔记

背景

本文提出了一种高性能、可解释的规范深度表格数据学习结构TabNet。号称吊锤XGBoost和LightGBM等GBDT类模型。来吧,开学!

TabNet使用sequential的attention来选择在每个决策步骤中要推理的特征,使得学习被用于最显著的特征,从而实现可解释性和更有效的学习。我们证明了TabNet在广泛的非性能饱和表格数据集上优于其他变体,并产生了可解释的特征属性和对其全局行为的洞察。

最后,我们展示了表格数据的自监督学习,在未标记数据丰富的情况下显著提高了效果。

1. 决策树类模型在诸多的表格型问题中仍然具有非常大的优势:

  • 对于表格型数据中常见的具有近似超平面边界的决策流形,它们是表示有效的;
  • 它们的基本形式具有高度的可解释性(例如,通过跟踪决策节点),并且对于它们的集成形式有流行的事后可解释性方法;
  • 训练非常快;

2. DNN的优势:

  • 有效地编码多种数据类型,如图像和表格数据;
  • 减轻特征工程的需要,这是目前基于树的表格数据学习方法的一个关键方面;
  • 从流式数据中学习;
  • 端到端模型的表示学习,这使得许多有价值的应用场景能够实现,包括数据高效的域适配;

3. TabNet:

  • TabNet无需任何预处理即可输入原始表格数据,并使用基于梯度下降的优化方法进行训练,实现了端到端学习的灵活集成。
  • TabNet使用sequential attention来选择在每个决策步骤中从哪些特征中推理,从而实现可解释性和更好的学习,因为学习能力用于最显著的特征。这种特征选择是基于实例的,例如,对于每个输入,它可以是不同的,并且与其他基于实例的特征选择方法不同,TabNet采用了一种深度特征选择和推理的学习体系结构。
  • TabNet在不同领域的分类和回归问题的不同数据集上优于或等同于其他表格学习模型;
  • TabNet有两种可解释性:局部可解释性,用于可视化特征的重要性及其组合方式;全局可解释性,用于量化每个特征对训练模型的贡献。
  • 最后,对于表格数据,我们首次通过使用无监督预训练来预测掩蔽特征,得到了显著的性能提升;

类似于DTs的DNN building blocks

v2-9190a5eb0f6cb7eb9b7d72d329d9f34a_b.jpg
  • 使用从数据中学习的稀疏实例特征选择;
  • 构造一个连续的多步骤体系结构,其中每个步骤有助于基于所选特征的决策的一部分;
  • 通过对所选特征的非线性处理来提高学习能力;
  • 通过更高的维度和更多的步骤来模拟融合。

我们使用所有的原始数值特征并且将类别特征转化为可以训练的embedding,我们并不考虑全局特征normalization。

在每一轮我们将D维度的特征传入,其中B是batch size, TabNet的编码是基于序列化的多步处理, 有N个决策过程。在第i步我们输入第i-1步的处理信息来决定使用哪些特征,并且输出处理过的特征表示来集成到整体的决策。

v2-b6d1a9561c0302faeec27ce5a66a73d8_b.jpg

v2-abb11a24b628e59dd6dc4ab98dafb2ed_b.jpg

v2-1dc6dc2788688f089e859ba0889d400e_b.jpg

v2-43ae95d5234b6fe05df2b6ced6d0ede7_b.jpg

实验

1. 基于实例的特征选择

v2-785478097a5d36a43046c2a2446bc718_b.jpg
  • TabNet比所有其他的模型都要好;
  • TabNet的效果与全局特征选择非常接近,它可以找到哪些特征是全局最优的;
  • 删除冗余特征之后,TabNet提升了全局特征选择;

2. 现实数据集上的表现

v2-1918b0d1917213028a18238340882211_b.jpg

v2-49ecf78de4bfecdb81a66409f03fb597_b.jpg

v2-06de40fbd680d1ccf59da3314d4957ef_b.jpg

v2-8a1b816bea3a15fddd04fdd55793a6cf_b.jpg

v2-094402710b8c62a55c05d848390ca35c_b.jpg
  • TabNet在多个数据集上的效果都取得了最好的效果;

3. 自监督学习

v2-936d7d8cfded97f2c4d7830bab71c17d_b.jpg
  • 无监督预训练显著提高了有监督分类任务的性能,特别是在未标记数据集比标记数据集大得多的情况下;
  • 如上图所示,在无监督的预训练下,模型收敛更快。快速收敛有助于持续学习和领域适应.

小结

TabNet,一种新的用于表格学习的深度学习体系结构。TabNet使用一种顺序attention机制来选择语义上有意义的特征子集,以便在每个决策步骤中进行处理。基于实例的特征选择能够有效地进行学习,因为模型容量被充分地用于最显著的特征,并且通过选择模板的可视化产生更具解释性的决策。我们证明了TabNet在不同领域的表格数据集上的性能优于以前的工作。最后,我们展示了无监督预训练对于快速适应和提高模型的效果。

v2-071e132c7353c2be34deb925be1bf617_b.jpg
更多干货,请关注微信公众号:炼丹笔记

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值