神经支持决策树(NBDT)算法研究

本文介绍了神经支持决策树(NBDT)算法,它结合了CNN和决策树,以提高模型的解释性。NBDT通过预训练的CNN获取特征,然后利用层次聚类和WordNet构建决策树结构。通过在总损失中加入树结构损失,模型在保持高准确率的同时提供了分类决策路径。实验表明,NBDT在精度上略低于原始CNN,但在可解释性方面优于其他模型。文章还解析了NBDT的源码,探讨了其在商品分类、图片识别等场景中的应用潜力。
摘要由CSDN通过智能技术生成

背景

在闲鱼的很多业务场景中有大量需要利用算法进行分类的需求,例如图片分类、组件识别、商品分层、纠纷类别预测等。这些场景往往需要模型识别出的结果具备可解释性,也就是识别不能只得到其类别,最好能在识别过程中同时解释类别的层级和来源。如何进行有解释的图片分类成为了项目研发中的一个需求,基于此我对NBDT算法进行了调研。

NBDT 是UC伯克利和波斯顿大学最新(2020年4月)发的一篇paper中的模型。NBDT全称“Neural-Backed Decision Trees”,翻译为“神经支持决策树”,特别强调此处“B”不代表“Boosting”,以免熟悉GBDT的同学可能会误以为NBDT又是一种新型的梯度提升树模型。NBDT只是一颗决策树,而不是多棵树。

介绍

NBDT的特点在于它在决策树中(准确说是决策树)融入了神经网络NN,这里NN通常是CNN即卷积神经网络。个人理解,NBDT的结构可以大致认为是“前面的CNN + 后面的DT”。DT=决策树。NBDT目前的使用场景是在图像分类领域。它的优势不在于准确率有多高,事实上在作者的实验中,它的准确率是略低于“前面的CNN”的。它的真正优势是能够很好的平衡模型准确率模型解释性。具体来讲,它可以在略微牺牲CNN的准确率的前提下,取得比任何树模型都高的多的(分类)准确率,同时因为它融入了决策树,还可以显式的、逐级的给出模型推断的依据,也就是说,NBDT不但可以把一张狗的图片识别为“狗”,还可以告诉你它是如何一步一步识别的:比如,先把该图片以99.49%的概率识别为“动物”,再以99.63%的概率识别成“脊椎动物(Chordate)”,然后以99.4%的概率识别成脊椎动物下的“食肉动物(Carnivore)”,最后以99.88%的概率判断成食肉动物下的“狗”。这种推断方式无疑增强了模型的解释力。

图1 - 狗狗分类 (引用自官方Demo)

原理

NBDT采用了“预训练+finetune”的框架。整个流程大致分为以下三步:

预训练一个CNN模型,并拿CNN最后一层的权重作为“每种类别”的隐向量

比如先拿cifar10(一个图片分类数据集,有“猫”、“狗”之类的10种类别)训练一个resnet18的CNN。这类CNN的最后一层通常是全连接层(Fully Connected layer, FC),设倒数第二层输出的向量维度为d,则该全连接层W的维度为W,那么W的每一个列向量正好对应了每一个类别,可以将其视作每一种类别的隐向量。这种做法有点类似于Word2Vec。

利用类别的隐向量做层次聚类(hierarchical clustering)并利用wordnet形成层次树结构。

论文中将该树结构称之为“诱导

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值