Frosst N, Hinton G. Distilling a neural network into a soft
decisiontree [J] . arXiv preprint arXiv:1711. 09784, 2017
摘要
深度神经网络已被证明是执行分类任务的一种非常有效的方法。 当输入数据为高维数据、输入输出关系复杂、标注训练样本数量大时,这种算法表现优异[Szegedy et al.,2015, Wu et al.,2016, Jozefowicz et al.,2016, Graves et al., 2013]。 但是很难解释为什么一个习得的网络会在一个特定的测试用例上做出一个特定的分类决定。 这是由于它们对分布式层次表示的依赖。 如果我们可以利用神经网络获得的知识,并将相同的知识表达在一个依赖于层次决策的模型中,解释一个特定的决策将会容易得多。我们描述了一种使用经过训练的神经网络来创建一种软决策树类型的方法,它比直接从训练数据中学习的决策树具有更好的泛化能力。
1介绍
深度神经网络优秀的泛化能力依赖于它们在隐藏层中使用分布式表示[LeCun et al.,2015],但这些表示难以理解。 对于第一个隐藏层,我们可以理解是什么导致了一个单位的激活,对于最后一个隐藏层,我们可以理解激活一个单位的效果, 但对于其他隐藏层来说,就有意义的变量(如输入和输出变量)而言,要理解特征激活的原因和影响就困难得多。 此外,隐含层中的单元将输入向量的表示分解为一组特征激活,这样,活跃特征的组合效应可以在下一个隐含层中产生适当的分布式表示。 这使得我们很难单独理解任何特定特征激活的功能角色,因为它的边际效应取决于同一层中所有其他单元的效果。
深度神经网络通过对训练数据的输入和输出关系中的大量的弱统计规律进行建模以做出可靠决定的特征进一步加剧了理解它的难度。同时,神经网络不能区分那些真正是数据的属性的弱统计规律和由训练集中的采样特性产生的伪规律。面对这些困难,似乎明智的做法是放弃试图通过理解单个隐藏单元的行为来理解深度神经网络如何做出分类决策。
相比之下,很容易解释决策树如何进行任何特定的分类,因为这取决于一个相对较短的决策序列,并且每项决策都直接基于输入的数据。然而,决策树通常不能像深度神经网络那样泛化。 与神经网络中的隐藏神经元相比,,一个典型的决策树低层节点只被很小一部分训练数据使用,因此往往会过拟合,除非训练集大小是树深度的指数级规模。
在本文中,我们提出了一种新的方法来解决泛化和可解释性之间的紧张关系。我们不是试图理解深度神经网络是如何做出决策的,而是使用深度神经网络来训练一棵决策树,它模仿神经网络发现的输入-输出函数,但以完全不同的方式工作。 如果有大量的未标记数据,神经网络可以创建一个更大的有标记的数据集来训练决策树,从而克服了决策树的统计效率低下的问题。即使无法获得无标签数据,也可以使用生成建模的最新进展[Goodfellow等人,2014,Kingma和Welling, 2013]从接近数据分布的分布中生成合成的无标签数据。在不使用未标记数据的情况下,仍然可以使用一种称为蒸馏的技术[Hinton et al.,2015, Buciluˇa et al.,2006]和一种软决策树,将神经网络的泛化能力转移到决策树中。
在测试阶段,我们使用决策树作为我们的模型。它的表现可能比神经网络略差,但通常会快得多,我们现在有了一个模型,我们可以直接解释和参与它的决策。
我们从描述我们使用的特定类型的决策树开始。 这一选择是为了便于将由深度神经网络获得的知识蒸馏成决策树。
2 The Hierarchical Mixture of Bigots
我们使用用小批数据梯度下降训练的软二叉决策树,其中每个内部节点i有一个习得的过滤器wi和一个偏置bi,每个叶子节点l有一个学习过的分布Ql。在每个内部节点,取最右分支的概率为
其中x是模型的输入,σ是sigmoid逻辑斯蒂函数。
这个模型是专家的层级混合(a hierarchical misture of experts)[Jordan and Jacobs, 1994]。但是每个专家都是偏执狂,他们在训练之后不看数据,因此总是产生相同的分布。该模型学习到了一套过滤器的层次结构,用来将每个例子以特定的路径概率分配给特定的专家,每个专家学习了一个在可能的输出类,k,上的一个简单的静态分布。
其中Ql表示第l个叶子的概率分布,每一个φl都是那个叶子上学到的参数。
图1:只有一个内部节点和两个叶子节点的软二叉树
为了避免在树中做出非常软的决定,我们在计算sigmoid之前向滤波器激活中引入了一个反温度β,因此在节点i处取右分支的概率为pi(x)=σ(β(xwi+bi))。
这个模型可以用两种不同的方式来给出类的预测分布,即使用具有最大路径概率的叶节点的分布,或以路径概率加权平均所有叶节点上的分布。如果我们取具有最大路径概率的叶节点的预测分布,那么对该预测的解释就是到叶节点的路径上所有过滤器的列表,以及二元激活决策。 如果我们以路径概率加权平均所有叶节点上的分布,我们发现模型在测试时达到略好的边缘精度,但这将导致模型在一个特定实例上的预测分布的解释的复杂性指数增长,因为它涉及到过滤器的所有节点。因此,在本文的剩余部分中,当我们引用模型的输出时,我们将采用具有最大路径概率的叶节点的分布。
我们使用一个损失函数来训练软决策树。该函数寻求最小化每对叶子之间的交叉熵,由其路径概率和目标分布加权。对于单个输入向量x,目标分布T的训练案例,损失为:
其中T是目标分布,Pl(x)是给定输入x最终到达叶子l的概率。
与大多数决策树不同,我们的软决策树使用的决策边界与输入向量的组件定义的轴不一致。 此外,它们的训练方式是首先选择树的大小,然后使用小批量梯度下降来同时更新所有参数,而不是更标准的贪心算法,即每次决定拆分一个节点[Friedman等人,2001]。
3 正则化
为了避免在训练期间陷入糟糕的解决方案中,我们引入了一个惩罚项,鼓励每个内部节点平等地使用左右子树。 如果没有这种惩罚,树就会陷入停滞状态,:一个或多个内部节点总是将几乎所有的概率分配给它的一个子树,并且这个决策的逻辑的梯度总是非常接近于零。 惩罚是两个子树的期望平均分布0.5,0.5和实际分布α,(1−α)之间的交叉熵,其中节点i的α由以下公式给出:
其中Pi(x)是从根节点到节点i的路径概率。惩罚在所有内部节点上叠加, 为:
λ是决定惩罚强度的超参数,在训练前设置。 这种惩罚是基于这样一种假设,即对可选子树进行同等使用的树通常更适合任何特定的分类任务,而且在实践中它确实提高了准确性。 然而,当一个人沿着树往下走的时候,这种假设就越来越不正确了; 树中的倒数第二个节点可能只负责两类输入,且比例不相等,在这种情况下惩罚非等分节点可能会损害模型的准确性。 我们发现,当惩罚的强度随着树中节点深度d呈指数衰减时,我们获得了更好的测试精度结果,因此惩罚的强度与2^ - d成正比。
沿着树向下,每个节点在任何给定的训练批中看到的数据的期望分数呈指数下降。这意味着使用这两个子树计算实际概率变得不那么精确。 为了解决这个问题,我们可以维护一个成指数下降的实际概率的运行平均值,以及一个时间窗口与节点深度成指数比例。 我们通过实验发现,通过使用惩罚强度随深度的指数衰减和用于计算运行平均值的时间窗口长度的指数增长,我们获得了更好的测试精度。
图2:这是在MNIST上训练的深度为4的软决策树的可视化。 内部节点上的图像是习得的过滤器,而叶子上的图像是习得的概率分布在类上的可视化。 最后每个叶子最可能的分类,以及其它可能的分类标注在每个边缘。 如果我们以最右边的内部节点为例,我们可以看到在树的那个层次上,潜在的分类只有3或8,因此习得的的过滤器只是学习区分这两个数字。 这个节点是一个过滤器,它查找将3的两端连接成8的两个区域是否存在。
5 解释软决策树如何分类
这项工作背后的主要动机是创建一个行为容易解释的模型; 为了完全理解为什么一个特定的示例被赋予了一个特定的分类,可以简单地检查在根节点和分类的叶节点之间的路径上所学习到的所有过滤器。这个模型的关键在于它不依赖于分层特征,而是依赖于分层决策。传统神经网络的分层特征允许它学习输入空间的健壮和新颖的表示,但过了一到两层,它们就变得难以接触。一些当前尝试解释神经网络依靠使用梯度下降法发现特别能激活特定的神经元的输入(Simonyan et al ., 2013年,Erhan et al ., 2009),但结果是复杂流形(manifold)上的一个点,这意味着其它输入也能得到同样的神经元激活,所以它不能反映整个流形。Ribeiro等人提出了一种策略,该策略依赖于拟合一些可解释的模型,该模型“在可解释组件存在/不存在下”,对输入空间中某些感兴趣的部分周围的神经网络的行为采取行动[Ribeiro等人,2016]。这是通过从输入空间采样和在感兴趣的区域周围查询模型,然后将一个可解释的模型拟合到模型的输出来完成的。 这避免了试图通过可视化流形上的单个点来解释特定输出的问题,但引入了输入空间中每个感兴趣区域都需要一个新的可解释模型的问题,并试图通过输入空间的一阶离散解释来解释模型行为的变化。通过依赖分层决策而不是分层特性,我们回避了这些问题,因为每个决策都是在读者可以直接参与的抽象级别上做出的。
7 结论
我们描述了一种方法,使用训练过的神经网络,以软决策树的形式,创建一个更易于解释的模型。软决策树通过随机梯度下降训练,使用神经网络的预测,以提供更多富含信息的目标。软决策树使用习得的过滤器基于输入示例做出分层决策,并最终选择类的特定的静态概率分布作为输出。这种软决策树的泛化性能优于直接对数据进行训练的决策树,但不如用于提供软目标训练的神经网络。如果有必要可以解释为什么一个模型把一个特定的测试用例以特定的方式分类,我们可以使用一个软决策树,但是我们仍然可以使用深度神经网络提高这个可解释模型的训练。