背景
在闲鱼的很多业务场景中有大量需要利用算法进行分类的需求,例如图片分类、组件识别、商品分层、纠纷类别预测等。这些场景往往需要模型识别出的结果具备可解释性,也就是识别不能只得到其类别,最好能在识别过程中同时解释类别的层级和来源。如何进行有解释的图片分类成为了项目研发中的一个需求,基于此我对NBDT算法进行了调研。
NBDT 是UC伯克利和波斯顿大学最新(2020年4月)发的一篇paper中的模型。NBDT全称“Neural-Backed Decision Trees”,翻译为“神经支持决策树”,特别强调此处“B”不代表“Boosting”,以免熟悉GBDT的同学可能会误以为NBDT又是一种新型的梯度提升树模型。NBDT只是一颗决策树,而不是多棵树。
介绍
NBDT的特点在于它在决策树中(准确说是决策树前)融入了神经网络NN,这里NN通常是CNN即卷积神经网络。个人理解,NBDT的结构可以大致认为是“前面的CNN + 后面的DT”。DT=决策树。NBDT目前的使用场景是在图像分类领域。它的优势不在于准确率有多高,事实上在作者的实验中,它的准确率是略低于“前面的CNN”的。它的真正优势是能够很好的平衡模型准确率和模型解释性。具体来讲,它可以在略微牺牲CNN的准确率的前提下,取得比任何树模型都高的多的(分类)准确率,同时因为它融入了决策树,还可以显式的、逐级的给出模型推断的依据,也就是说,NBDT不但可以把一张狗的图片识别为“狗”,还可以告诉你它是如何一步一步识别的:比如,先把该图片以99.49%的概率识别为“动物”,再以99.63%的概率识别成“脊椎动物(Chordate)”,然后以99.4%的概率识别成脊椎动物下的“食肉动物(Carnivore)”,最后以99.88%的概率判断成食肉动物下的“狗”。这种推断方式无疑增强了模型的解释力。
图1 - 狗狗分类 (引用自官方Demo)
原理
NBDT采用了“预训练+finetune”的框架。整个流程大致分为以下三步:
预训练一个CNN模型,并拿CNN最后一层的权重作为“每种类别”的隐向量
比如先拿cifar10(一个图片分类数据集,有“猫”、“狗”之类的10种类别)训练一个resnet18的CNN。这类CNN的最后一层通常是全连接层(Fully Connected layer, FC),设倒数第二层输出的向量维度为d,则该全连接层W的维度为W,那么W的每一个列向量正好对应了每一个类别,可以将其视作每一种类别的隐向量。这种做法有点类似于Word2Vec。
利用类别的隐向量做层次聚类(hierarchical clustering)并利用wordnet形成层次树结构。
论文中将该树结构称之为“诱导