TabNet论文笔记

1.简介

本文根据2020年《TabNet: Attentive Interpretable Tabular Learning》翻译总结的。TabNet,一个注意力的可解释的表格学习方法。

XGBoost和LightGBM近几年在表格数据处理上占据了统治地位,是基于梯度提升决策树(GBDT)的,不是DNN(deep neutral network)。DNN在处理表格数据方面一直没有较大的进展。

TabNet使用了DNN,实验结果超过了XGBoost和LightGBM。主要是准确率方面。我实践中发现TabNet的速度确实慢很多,慢几倍吧。

首先说决策树为什么会占据表格数据处理的统治地位:(1)表格数据中存在近似超平面边界,决策树处理起来非常有效率,表现也很好;(2)决策树的解释性很强,可以有效的事后解释。

TabNet的贡献如下:
1)不像基于树的方法,TabNet输入未加工的数据,不需要任何特征处理。使用梯度下降进行训练。可以灵活的集成在端到端的学习。
2)TabNet在决策树的每一步,使用序列注意力来选择特征进行推理。具有可解释性,更好的进行学习。特征的选择是实例级别的(instance-wise),它可以对每个输入是不同的。TabNet采用了单个深度学习框架进行端到端学习。
3)TabNet超越或者与其他表格学习方法不相上下。它具有两种可解释性,一个是local 解释性,即展示了各输入特征的重要性和它们如何被组合在一起;另一个是global解释性,量化输出每个输入特征的贡献。
4)最后使用了非监督预训练来预测masked 特征,我们的经典CNN设计架构取得了显著的效果提升,我们是第一个对表格数据进行自监督学习的。

下图主要是对上面第2点的描述,特征选择的示例,可以看到分为两类集合特征,分别是职业(professional occupation)和投资(investments)。

在这里插入图片描述

2.TabNet进行表格数据学习

使用传统DNN模块构建决策树分类器:

决策树对于真实世界的表格数据学习是很成功的。可以使用传统的DNN模块来实施类似的决策树,如下图。独立特征的选择是在超平面形式中获取决策边界的关键。或者说,对特征的线性组合,其中每个成分系数决定着每个特征的比例。TabNet就是基于如此的像树的函数。(1)使用稀疏的实例级别的特征选择;(2)构建一个序列的(按顺序的)多步结构,基于选择的特征,每一决策步骤可以贡献决策的一部分;(3)通过选择特征的非线性处理,来改善学习能力;(4)凭借高维和多步骤,模仿集成处理。
在这里插入图片描述

TabNet结构:

表格数据一般分为数字型特征和分类型特征。我们使用原始的数字型特征,将分类型特征换成可以训练的embedding(词嵌入)。我们不考虑任何全局的归一化特征,仅使用BN(batch normalization)。
TabNet结构如下图,一个encoder和一个decoder,其中encoder里有feature transformer、attentive transformer;而decoder只有feature transformer。
在这里插入图片描述

特征选择(attentive transformer):

特征维度:D
Batch size:B
在每个决策步骤通过相同的D维特征f∈R^(B*D)
决策步骤数N_steps
第i步输入上一步(i-1)的处理信息,决定使用哪个特征,并且输出的特征代表会被集成到整体决策中。

可学习的mask M[i]∈R^(B*D),用来作为显著特征的软选择(soft selection)。通过大部分显著特征的稀疏选择,每一个决策步骤的学习不相关特征的能力不会被浪费掉。

Mask采用乘积的形式 M[i]*f
基于上一步a[i-1]处理的特征信息,采用attentive transformer 来回去mask,如下所示:
在这里插入图片描述

其中 在这里插入图片描述

h_i是一个可以训练的函数,如上图attentive transformer的部分,包括一个FC层、BN。
P[i]是上图的“prior scale”。表示一个特征有多少是被先前使用的。
在这里插入图片描述

上式的γ是一个放松参数。当γ=1时,一个特征会被仅在一个决策步骤中使用;当γ增加时,可以在多个决策步骤中使用一个特征。

稀疏(sparsity)方程如下:
在这里插入图片描述

特征处理(feature transformer):

如架构图中的feature transformer,分为两部分,shared across和step-dependent。shared across是在所有决策步骤中。

3.自监督学习

主要是TabNet架构图中的decoder。
在这里插入图片描述

4.实验结果

4.1 实际数据表现

Forest cover type:森林覆盖类型,是一个分类任务。从下表可以看出来TabNet好于XGBoost、LightGBM,和一些已知的神经网络模型AutoInt、AUTOML。
在这里插入图片描述

Rossmann store sales:预测商店销售,基于静态和随时间改变的特征。
在这里插入图片描述

4.2 可解释性

如下表,对成年人收入的特征重要性排名,TabNet也取得了和其他方法类似的效果。比如年龄是第一位。
在这里插入图片描述

4.3 自监督学习

可以看到非监督预训练(pre-training)显著改善了监督学习的效果,特别是当未标注的数据比标注数据多很多时。
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值