NBDT补充

前面提到的是对论文的浅析,并没有很好的把实现的细节讲出来。

在代码复现时,发现很多细节问题。

参考咸鱼大神的文章:

神经支持决策树(NBDT)算法研究_闲鱼技术的博客-CSDN博客

整个框架搭建分成四个步骤

将一个完整NN(全连接层及以上)和决策树结合,决策树的节点都有语义信息和向量信息

预训练一个CNN模型,并拿CNN最后一层的权重作为“每种类别”的隐向量

      比如先拿cifar10(一个图片分类数据集,有“猫”、“狗”之类的10种类别)训练一个resnet18的CNN。这类CNN的最后一层通常是全连接层(Fully Connected layer, FC),设倒数第二层输出的向量维度为d,则该全连接层W的维度为d,那么W的每一个列向量正好对应了每一个类别,可以将其视作每一种类别的隐向量。这种做法有点类似于Word2Vec。

总结

      利用上一个预训练网络的全连接层(FC)的权重W,每一列对应于一个类别。(这里注意维度的问题,FC层输出和FC的前一层的输出维度一致,因为FC是1*1)。

    所以可以构建叶子节点。 既可以利用已知标签给叶子 语义信息,又可以利用权重矩阵的列赋予叶子节点 特征向量(隐向量)。

利用类别的隐向量做层次聚类(hierarchical clustering)

并利用wordnet形成层次树结构。

     论文中将该树结构称之为“诱导层级”(Induced Hierarchy)。具体地,首先对类别隐向量做层次聚类,源码中是直接调用sklearn模块的AgglomerativeClustering类实现。聚类的分层结构有了之后,带来了两个问题:(1)两个子节点可以被聚类算法聚到一起,子节点都表示一类实体,但它们的父节点并没有一个实体的描述。(2)假设两个子节点被聚到了一起,子节点都有隐向量,它们的父节点的隐向量该怎么表示?

     针对问题(1),作者使用了WordNet,一种包含名词之间上下位关系的词网络,python里面可以直接在nltk模块中导入wordnet模块调用。由于叶节点是存在实体描述的,比方说cifar10的10个类别,那么通过WordNet,可以找到两个叶节点“最邻近的共同祖先”,e.g. “猫”和“狗”在WordNet中可能最近的归属是都位于“哺乳动物”下,那么“哺乳动物”就被作为“猫”和“狗”的父节点。因此,可以按照层次聚类的结果,自底向上依次为父节点“命名”,直到只有一个根节点,这就形成了所谓的“诱导层级”。

     针对问题(2),作者使用了子节点隐向量的均值,来代表父节点的隐向量。

总结

    有了叶子节点,再尝试逐层向上构建父节点。 代码中提供了一个参数  n_clusters  ,表示每次利用几个节点向上“寻父”。

    寻父则需要利用wordnet 找 子节点之间最近的共同关系(语义信息),而且还要利用子节点 特征向量的均值(向量信息),最后能够构建完整的一棵树。

在总损失中加入诱导层级的分类损失,finetune模型

确定损失函数(原始  +  树监督(Soft / Hard))

      在诱导层级(树结构,下称DT)有了之后,完整的模型不再是CNN,而是CNN+DT。为了迫使模型对新样本的预测能够遵循树结构从根节点一路推断至叶节点,就需要在总损失中加入树结构的分类损失,并对模型做finetune。

      这里首先要理解完整模型预测所采用的方式。一个新的样本(一张图片)进来,首先要经过前面的CNN,在最后一层的全连接层W之前,CNN给该图片输出的是一个d维向量x。

        将x与W做矩阵乘法(实质上是与各列向量做内积),即得到该样本在各个类别的logits分布,如果再softmax则得到了概率分布(这是单独的NN做分类决策的方法)。

        由于W的各列向量代表着DT叶节点的隐向量,那么完全可以用该DT来替换W,不再直接把x与W做矩阵乘法,而是从DT的根节点开始遍历,让x依次与DT各节点的子节点隐向量计算内积。这里遍历DT各节点有两种模式:“Hard”和“Soft”。以DT是二叉树为例,若是Hard模式,那么每次x会与左右两边的子节点分别算内积,哪边大就把x归为哪一边,一直计算到叶节点为止,最后x落到的叶节点,即为x所属的最终类别。若是Soft模式,则x会自顶向下遍历全部中间节点并计算内积,然后叶节点的最终概率是到达叶节点的路径上各中间节点的概率之乘积,最后通过比较各叶节点上的最终概率值的大小,即可确定x所属类别。

下面是其交叉熵损失函数的计算公式:

代码中权重更新的到底是树 还是  神经网络呢?

        Loss进行BP反向传播时优化的依然是CNN的网络权重,直观上理解:就是迫使前面CNN的输出能够符合后面DT的预期,尽可能使得样本按照DT的推断路径输出的预测类别符合其真实类别。

今天就补充到这里了,晚安~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值