#第一性解释

可解释性终极追问,什么才是第一性解释?20篇CCF-A+ICLR论文给你答案

本文作者为张俊鹏、任启涵、张拳石,其中张俊鹏是张拳石老师的准入学博士生,任启涵是张拳石老师的博士生。

本文首先简单回顾了『等效交互可解释性理论体系』(20 篇 CCF-A 及 ICLR 论文),并在此基础上,严格推导并预测出神经网络在训练过程中其概念表征及其泛化性的动力学变化,即在某种程度上,我们可以解释在训练过程中神经网络在任意时间点的泛化性及其内在根因。

一、前言

长期以来,我们团队一直在思考可解释性领域的一个终极问题,即什么才是解释性领域的第一性原理?所谓第一性原理,目前没有一个被广泛接受的框架,世上本无路,我们需要逐渐去定义这样一个路。我们需要在一个新的理论体系中,提出大量的公理性要求,得出一个可以从不同的角度全方位精确严谨解释神经网络内在机理的理论。一套理论系统能严谨解释神经网络的方方面面才叫 “第一性原理”

如果你真的在严谨地做 “科学”,那么第一性原理一定不是想象中简单,而是一个复杂的体系,需要研究照顾到深度学习中方方面面纷繁复杂的现象。当然,如果你主观上不愿意或者不信一个理论需要足够严谨,那么研究会变得简单千万倍。就像物理学的标准模型一定比牛顿定律复杂,取决于你希望走哪条路。

沿着这个方向,我们团队独立从头构建了『等效交互可解释性理论体系』,并基于此理论,从三个角度来解释神经网络的内在机理。

1. 语义解释的理论基础:数学证明神经网络的决策逻辑是否可以被少量符号化逻辑所充分覆盖(充分解释)。『证明神经网络的决策逻辑是否可以被有限符号化逻辑解释清楚』这一命题是解释神经网络的根本命题。如果此命题被证伪,则从根本上讲,神经网络的可解释性将是无望的,所有的解释性算法只能提供近似的解读,而无法精确地覆盖所有的决策逻辑。幸运的是,我们找到了在大部分应用中神经网络都可以满足的面向遮挡鲁棒性的三个常见的条件,并且数学证明了满足这三个条件的神经网络的决策逻辑可以被写成符号化的交互概念。

参见 https://zhuanlan.zhihu.com/p/693747946

2. 寻找性能指标背后的可证明、可验证的根因:将神经网络泛化性和鲁棒性等终极性能指标的根因拆分具体少数细节逻辑。对神经网络性能(鲁棒性、泛化性)的解释是神经网络可解释性领域的另一个重大问题。然而,目前人们普遍认为神经网络性能是对神经网络整体的描述,而神经网络无法像人类一样将自己的分类判断拆解成具象化的、少量的决策逻辑。在这方面,我们给出了不一样的观点 —— 将性能指标与具象化的交互之间建立起数学关系。我们证明了 1. 等效交互的复杂度可以直接决定神经网络的对抗鲁棒性 / 迁移性,2. 交互的复杂度决定了神经网络的表征能力,3. 并解释神经网络的泛化能力 [1],和 4. 解释神经网络的表征瓶颈。


  • 参见1:https://zhuanlan.zhihu.com/p/369883667
  • 参见2:https://zhuanlan.zhihu.com/p/361686461
  • 参见3:https://zhuanlan.zhihu.com/p/704760363
  • 参见4:https://zhuanlan.zhihu.com/p/468569001

3. 统一工程性深度学习算法。由于缺少基础理论的支撑,目前深度学习算法大都是经验性的、工程性的。可解释性领域的第一性原理应该可以承担起将前人的大量工程性经验总结为科学规律的任务。在等效交互可解释性理论体系下,我们团队既证明了 14 种不同的输入重要性归因算法的计算本质在数学上都可以统一写成对交互作用的再分配形式。此外,我们还统一了 12 种提升对抗迁移性的算法,证明了所有提升对抗迁移性算法的一个公共机理是降低对抗扰动之间的交互效用,实现了对神经网络可解释性方向大部分工程性算法的理论凝练。

  • 参见1:https://zhuanlan.zhihu.com/p/610774894
  • 参见2:https://zhuanlan.zhihu.com/p/546433296

在等效交互可解释性理论体系下,我们的团队在之前的研究中已经成功发表了 20 篇 CCF-A 类和机器学习顶级会议 ICLR 论文,我们已经从理论和实验上充分解答了上述问题。

二、本文研究概述

沿着上述理论框架,在这篇知乎文章中,我们希望精确解释出神经网络训练过程中泛化性的变化规律,具体地涉及两篇论文。

  • 1.Junpeng Zhang, Qing Li, Liang Lin, Quanshi Zhang,“Two-Phase Dynamics of Interactions Explains the Starting Point of a DNN Learning Over-Fitted Features”,in arXiv: 2405.10262
  • 2.Qihan Ren, Yang Xu, Junpeng Zhang, Yue Xin, Dongrui Liu, Quanshi Zhang,“Towards the Dynamics of a DNN Learning Symbolic Interactions” in arXiv:2407.19198

图 1:两阶段现象的示意图。在第一阶段,神经网络逐渐消除中高阶交互,学习低阶交互;在第二阶段,神经网络逐渐建模阶数不断增大的交互。当神经网络训练过程中测试损失和训练损失之间的 loss gap 开始增大时,神经网络恰好也进入训练的第二阶段。

我们希望在等效交互框架里提出新的理论,精确预测出神经网络每一个时间点上神经网络所学到的交互概念的数量、复杂度,以及泛化性变化的动力学规律(如图 1 所示)。具体地,我们希望证明出两方面结论。

第一,基于前人的证明(一个神经网络的决策逻辑可以被严格解构表示为几十个交互概念效用的和的形式),进一步严格推导出在整个训练过程中,神经网络所建模的交互效用的变化动力学过程 —— 即理论需精确预测出在不同训练阶段,神经网络所建模的交互概念的分布的变化 —— 推导出哪些交互会在哪个时间点上被学习到

第二,寻找充分的证据,证明所推导的交互复杂度的变化规律客观反映出神经网络在全训练周期中泛化性变化的规律

综上两点,我们希望具体彻底解释清楚神经网络的泛化性变化的内在根因。

与前人的关系:当然大家可能第一反应想到神经正切核(NTK)[2],但是神经正切核只是把参数的变化曲线解了出来,而没办法进一步深入到决策逻辑层面进行解释,没有将神经网络建模的概念表征与其泛化性的关系建立起来,对泛化性的分析依然停留在特征空间分析的层面,而没有在【符号化概念逻辑】与【泛化性】之间建立起严格的关系。

三、两大研究背景

误会 1:神经网络的第一性表征是『等效交互』,而不是神经网络的参数和结构。单纯从结构层面分析神经网络是人们对神经网络泛化根本表征的误解。目前大部分神经网络泛化性研究主要着眼于神经网络的结构、特征、以及数据。人们认为不同的神经网络结构就自然对应不同的函数,并自然展现出不同的性能。

但是,事实上,如图 2 所示,结构的区别只是神经网络表征的表面形式。除去有明显缺陷的对性能有明显影响的神经网络,所有其他可以实现 SOTA 性能的具有不同结构的神经网络往往都建模了相似的等效交互表征,即不同结构的高性能神经网络在等效交互表征上往往都是殊途同归的 [3, 4]。虽然神经网络其中层特征内部是复杂的混乱的,虽然不同神经网络所建模的特征向量大相径庭,虽然神经网络中单个神经元往往建模了相对比较混乱的语义(不是严格清晰的语义),但是神经网络作为一个整体,我们从理论上证明神经网络的所建模的交互关系是稀疏的符号化的(而不是特征的稀疏性,具体见 “四、交互的定义” 章节),而且面向相同任务的完全不同的神经网络往往建模了相似的交互关系。

图 2:不同结构的神经网络所建模的等效交互往往是殊途同归的。对于一个相同的输入句子,面向两个相同任务的两个完全不同的神经网络建模往往相似的交互。

由于不同神经网络的参数和训练样本不一样,两个神经网络中没有任何一个神经元在表征上具有严格的一一对应关系,且每一个神经元往往建模着不同语义的混合模式。相比之下,正如上段分析,神经网络所建模的交互表征实际上是不同神经网络表征中的不变量。因此,我们有理由认为神经网络根本表征是等效交互,而不是其载体(参数和训练样本),符号化交互表征可能代表了知识表征的第一性原理(被交互的稀疏性定理、无限拟合性定理、以及殊途同归现象所保证,见 “四、交互的定义” 章节,具体详细研究见下面知乎文章。

参见:https://zhuanlan.zhihu.com/p/633531725

误会 2:神经网络的泛化性问题是一个混合模型问题,而不是一个高维空间的向量。如图 3 所示,传统的泛化性分析总是假设单个样本整体是高维空间的一个点,实际上神经网络对单个样本的表征是 mixture model 的形式 —— 实际上通过大量不同的交互来表达。我们发现简单交互的泛化能力比复杂交互的泛化能力更强,所以不再适合用一个简单标量来笼统表示整个神经网络在不同样本上的泛化能力。相反,同一个神经网络在不同的样本上建模了不同复杂度的交互关系,而不同复杂度的交互往往对应着不同泛化能力。通常情况下,神经网络建模的高阶(复杂)的交互往往难以泛化到测试样本上(测试样本上不会触发相同的交互),代表过拟合表征,而神经网络建模的低阶(简单)交互往往代表泛化性较强的表征,具体详细研究见 [1]。

图 3:(a)传统的泛化性分析总是假设单个样本整体是高维空间的一个点。(b)实际上神经网络对单个样本的表征是 mixture model 的形式,神经网络在单个样本会建模简单交互(可泛化的交互)和复杂交互(不可泛化的交互)。

四、交互的定义

让我们考虑一个深度神经网络

和一个输入样本

,它包含

个输入变量,我们用集合

表示这些输入变量的全集。令

表示 DNN 在样本

上的一个标量输出。对于一个面向分类任务的神经网络,我们可以从不同角度来定义其标量输出。例如,对于多类别分类问题,

可以定义为

,也可以定义为 softmax 层之前该样本真实标签所对应的标量输出。这里,

表示真实标签的分类概率。这样,针对每个子集

,我们可以用下面公式来定义

中所有输入变量之间 “等效与交互” 和 “等效或交互”。

如图 4(a)所示,我们可以这样理解上述与或交互:我们可以认为与等效交互表示神经网络所编码的

内输入变量之间的 “与关系”。例如,给定一个输入句子

,神经网络可能会在

之间建模一个交互,使得

产生一个推动神经网络输出 “倾盆大雨” 的数值效用。如果

中的任何输入变量被遮挡,则该数值效用将从神经网络的输出中移除。类似地,等效或交互

表示神经网络所建模的内输入变量之间的 “或关系”。例如,给定一个输入句子

,只要

中的任意一个词出现,就会推动神经网络的输出负面情感分类。

神经网络所建模的等效交互满足 “理想概念” 的三条公理性准则,即无限拟合性、稀疏性、样本间迁移性。

  1. 无限拟合性:如图 4,5 所示,对于任意遮挡样本,神经网络在样本上的输出可以用不同交互概念的效用之和来拟合。即,我们可以构造出一个基于交互的 logical model,无论我们如何遮挡输入样本,这个 logical model 依然可精确拟合模型在此输入样本在任意遮挡状态下的输出值。
  2. 稀疏性:面向分类任务的神经网络往往只建模少量的显著交互概念,而大部分交互概念都是数值效用都接近于 0 的噪声。
  3. 样本间迁移性:交互在不同样本间是可迁移的,即神经网络在(同一类别的)不同样本上建模的显著交互概念往往有很大的重合。

图 4:神经网络的复杂的推理逻辑可以被基于少量交互的逻辑模型

准确拟合。每个交互都是衡量神经网络建模特定输入变量集合

之间非线性关系的度量指标。当且仅当集合中变量同时出现时才会触发与交互,并为输出贡献数值分数

,集合中任意变量出现时会触发或交互。

图 5:神经网络在任意的遮挡样本上的输出可以用不同交互概念的效用之和来拟合,即我们可以构造出一个基于交互的 logical model,无论我们如何遮挡输入样本,哪怕穷举个输入单元上种完全不同的遮挡方式,这个 logical model 依然可精确拟合模型在此输入样本在任意遮挡状态下的输出值。

五、新的发现与证明

5.1 发现神经网络在训练过程中交互变化的两阶段现象

在这篇知乎文章中,我们关注神经网络解释性领域的一个根本问题,即如何从一个解析分析的角度去严格预测出神经网络在训练过程中泛化能力的变化情况,并且精确的分析神经网络从欠拟合到过拟合的整个动态变化过程及其背后的根本原因

首先,我们将交互的阶数(复杂度)定义为交互中的输入变量的数量,

。我们团队之前的工作发现神经网络在某个特定样本所建模的 “与或交互” 的复杂度直接决定了神经网络在这个样本的泛化能力 [1],即神经网络建模的高阶的(大量输入单元之间的)“与或交互” 往往有较差的泛化能力,而低阶的(少量输入单元之间的)“与或交互” 具有较强的泛化能力。

因此,本篇研究的第一步是去预测出神经网络在训练过程中不同时间点所建模的不同阶 “与或交互” 的复杂度的一个解析解,即我们可以通过神经网络在不同时间点所建模的不同阶 “与或交互” 的分布去解释神经网络在不同阶段的泛化能力。交互的泛化能力的定义与神经网络整体泛化能力的定义请见 “5.2 神经网络所建模交互的阶数和其泛化能力的关系” 章节。

我们提出两个指标来表示不同阶(复杂度)的交互的强度的分布。具体来说,我们用

来衡量所有阶正显著交互的强度,用

来衡量所有

阶负显著交互的强度,其中

表示显著交互的集合,

表示显著交互的阈值。

图 6:从训练不同轮次的神经网络中提取的不同阶交互强度

。在不同数据集上、不同任务上训练的不同的神经网络的训练过程都存在两阶段现象。前两个选定时间点属于第一阶段,而后两个时间点属于第二阶段。恰恰在进入神经网络训练过程的第二阶段不久,神经网络的测试损失和训练损失之间的 loss gap 开始显著上升(见最后一列)。这表明神经网络训练的两阶段现象与模型 loss gap 的变化在时间上是 “对齐” 的。更多实验结果请参见论文。

如图 6 所示,神经网络的两阶段现象具体表现为:

  • 在神经训练训练之前,初始化的神经网络主要编码中阶交互,很少编码高阶和低阶交互,并且不同阶交互的分布看起来呈现 “纺锤形”。假设具有随机初始化参数的神经网络建模的是纯噪声,我们在 “5.4 理论证明两阶段现象” 章节证明了具有随机初始化参数的神经网络建模的不同阶的交互的分布呈现 “纺锤形”,即仅建模少量的低阶和高阶交互,大量建模中阶交互。
  • 在神经网络训练的第一阶段,神经网络编码的高阶和中阶交互的强度逐渐减弱,而低阶交互的强度逐渐增强。最终,高阶和中阶交互逐渐被消除,神经网络只编码低阶交互。
  • 在神经网络训练的第二阶段,神经网络在训练过程中编码的交互阶数(复杂度)逐渐增加。在逐渐学习更高复杂度的交互的过程中,神经网络过拟合的风险也在逐渐提高。

上述的两阶段现象广泛存在于不同结构的神经网络训练于不同任务上的不同数据集的训练过程中。我们在图像数据集(CIFAR-10 数据集、MNIST 数据集、CUB200-2011 数据集(使用从图片中裁剪出来的鸟类图像)和 Tiny-ImageNet 数据集)上训练了 VGG-11/13/16 和 AlexNet。我们在 SST-2 数据集上训练了用于情感语义分类 Bert-Medium/Tiny 模型,我们在 ShapeNet 数据集中训练 DGCNN 来分类的 3D 点云数据。上图显示了不同的神经网络在不同训练时期提取的不同阶的显著交互的分布。我们在这些神经网络的训练过程中都发现了两阶段现象,更多实验结果及细节请参考论文。

5.2 神经网络所建模交互的阶数和其泛化能力的关系

我们团队之前的工作已经发现了神经网络所建模交互的阶数和其泛化能力的关系,即高阶交互比低阶交互具有更差的泛化能力 [1]。某个具体交互的泛化性有清晰的定义 —— 如果一个交互同时在训练样本和测试样本中频繁的被神经网络所建模,则这个交互具有较好的泛化能力。在本篇知乎文章中,介绍了两个实验来证明高阶交互具有较差的泛化能力,低阶交互具有较强的泛化能力。

实验一:观察在不同数据集上训练的不同神经网络所建模的交互的泛化性。这里我们用被测试集所触发的交互的分布和被训练集所触发的交互的分布的 Jaccard 相似性来度量交互的泛化性。具体来说,给定一个包含

个输入变量的输入样本

,我们将从输入样本提取到的

阶交互向量化

,其中

表示

个阶交互。然后,我们计算分类任务中所有类别为

的样本中提取到的阶的平均交互向量,表示为

,其中

表示类别为的样本的集合。接下来,我们计算从训练样本中提取的阶的平均交互向量

与从测试样本中提取的阶的平均交互向量

之间的 Jaccard 相似性,以衡量分类任务中类别为的样本的阶交互的泛化能力,即:

其中,

将两个

维交互向量投影到两个

维的非负向量上,以便计算 Jaccard 相似性。对于某一阶的交互,如果此阶交互普遍展现出较大的 Jaccard 相似性,则表示这一阶交互具有较强的泛化能力。

我们进行了实验计算不同阶交互

。我们测试了在 MNIST 数据集上训练的 LeNet、在 CIFAR-10 数据集上训练的 VGG-11、在 CUB200-2011 数据集上训练的 VGG-13,以及在 Tiny-ImageNet 数据集上训练的 AlexNet。为了减少计算成本,我们仅计算了前 10 个类别的 Jaccard 相似性的平均值

。如图 7 所示,随着交互阶数的增加,交互的 Jaccard 相似性不断下降。因此,这验证了高阶交互比低阶交互具有更差的泛化能力。

图 7:从训练样本和测试样本中提取的交互之间的 Jaccard 相似性。低阶交互具有相对较高 Jaccard 相似性表明低阶交互具有较强的泛化能力。

实验二:比较神经网络在正常样本和 OOD 样本建模的交互的分布。我们比较了从正常样本中提取的交互与从分布外 (OOD) 样本中提取的交互,以检查神经网络在 OOD 样本上是否建模更多的高阶交互。我们将少量训练样本的分类标签设置为错误标签。这样,数据集中的原始样本可以视为正常样本,而一些带有错误标签的样本则对应于 OOD 样本,这些 OOD 样本可能会导致神经网络的过拟合。我们在 MNIST 数据集和 CIFAR-10 数据集上分别训练了 VGG-11 和 VGG-13。图 8 比较了从正常样本中提取的交互的分布和从 OOD 样本中提取的交互的分布。我们发现,VGG-11 和 VGG-13 在分类 OOD 样本时建模了更多复杂的交互(高阶交互),而在分类正常样本时则使用了较低阶的交互。这验证了高阶交互的泛化能力通常弱于低阶交互。

图 8:比较从正常样本中提取的交互与从分布外 (OOD) 样本中提取的交互。神经网络通常在 OOD 样本上建模的更高阶的交互。

5.3 两阶段现象和神经网络训练过程 loss gap 的变化相对齐

我们发现上述两阶段现象可以充分表示神经网络泛化性动力学。一个很有趣的现象是神经网络训练过程中的两阶段现象和神经网络在测试集和训练集的 loss gap 的变化在时间上是对齐的。训练损失和测试损失之间的 loss gap 是衡量模型过拟合程度的最广泛使用的指标。图 6 显示了不同的神经网络在训练工程的测试损失和训练损失之间的 loss gap 的曲线,还显示了从不同训练时期的神经网络中提取的交互分布。我们发现当神经网络训练过程中测试损失和训练损失之间的 loss gap 开始增大时,神经网络恰好也进入训练的第二阶段。这表明神经网络训练的两阶段现象与模型 loss gap 的变化在时间上是 “对齐” 的。

我们可以这样理解上述现象:在训练过程开始前,初始化的神经网络所建模的交互全部表示随机噪声,并且不同阶交互的分布看起来像 “纺锤形”。在神经网络训练的第一阶段,神经网络逐渐消除中阶和高阶的交互,并学习最简单的(最低阶的)交互。然后,在神经网络训练的第二阶段,神经网络建模了阶数逐渐增大的交互。由于我们在 “5.2 神经网络所建模交互的阶数和其泛化能力的关系” 章节中的两个实验验证了高阶交互通常比低阶交互具有更差的泛化能力,因此我们可以认为在神经网络训练的第二阶段,DNN 首先学习了泛化能力最强的交互,然后逐渐转向更复杂但泛化能力较弱的交互。最终一些神经网络逐渐过拟合,并编码了大量中阶和高阶交互。

5.4 理论证明两阶段现象

理论证明神经网络训练过程的两阶段现象共分为三个部分,第一部分我们需要证明随机初始化的神经网络在训练过程开始之前建模的交互的分布呈现 “纺锤形”,即很少建模高阶和低阶交互,主要建模中阶交互。第二部分证明神经网络在训练的第二阶段在建模阶数逐渐增大的交互。第三部分证明神经网络在训练的第一阶段逐渐消除中阶和高阶交互,学习最低价的交互。

1. 证明初始化神经网络建模的 “纺锤形” 交互分布。

由于随机初始化的随机网络在训练过程开始之前建模的都是噪声,所以我们假设随机初始化的神经网络建模的交互的服从均值为

,方差为

的正态分布。在上述假设下,我们能够证明初始化的神经网络建模的交互的强度和的分布呈现 “纺锤形”,即很少建模高阶和低阶交互,主要建模中阶交互。

2. 证明神经网络训练的第二阶段的交互变化动态过程。

在进入正式的证明之前,我们需要做以下的预备工作。首先,我们参照 [5, 6] 的做法,将神经网络在特定样本上的 inference 改写为不同交互触发函数的加权和

其中,

为标量权重,满足

。而函数

为交互触发函数,在任意一个遮挡样本

上都满足

。函数的具体形式可以由泰勒展开推导得到,可参考论文,这里不做赘述。

根据上述改写形式,神经网络在特定样本上的学习可近似看成是对交互触发函数的权重

的学习。进一步地,实验室的前期工作 [3] 发现在同一任务上充分训练的不同的神经网络往往会建模相似的交互,所以我们可以将神经网络的学习看成是对一系列潜在的 ground truth 交互的拟合。由此,神经网络在训练到收敛时建模的交互可以看成是最小化下面的目标函数时得到的解:

其中

表示神经网络需要拟合的一系列潜在的 ground truth 交互。

则分别表示将所有权重拼起来得到的向量和将所有交互触发函数的值拼起来得到的向量。

可惜的是,上述建模虽然能得到神经网络训练到收敛时的交互,但是无法很好地刻画神经网络训练过程中学习交互的动态过程。这里引入我们的核心假设:我们假设初始化神经网络的参数上包含了大量噪声,而这些噪声的量级在训练过程中逐步变小。而进一步地,参数上的噪声会导致交互触发函数

上的噪声,且该噪声随着交互阶数指数级增长 (在 [5] 中已有实验上的观察和验证) 。我们将有噪声下的神经网络的学习建模如下:

其中噪声

满足

。且随着训练进行,噪声的方差

逐渐变小。

在给定的噪声量级的情况下最小化上述损失函数,可得到最优交互权重的解析解,如下图中的定理所示。

我们发现,随着训练进行(即噪声量级变小),中低阶交互强度和高阶交互强度的比值逐渐减小(如下面的定理所示)。这解释了训练的第二阶段中神经网络逐渐学到更加高阶的交互的现象。

另外,我们对上述结论进一步做了实验验证。给定一个具有 n 个输入单元的样本,指标

,其中

, 可以用来近似测量第 k 阶交互和第 k+1 阶交互强度的比值。在下图中,我们可以发现,在不同的输入单元个数 n 和不同的阶数 k 下,该比值都会随着的减小而逐渐减小。

图 9:在不同的输入单元个数 n 和不同的阶数 k 下,第 k 阶交互和第 k+1 阶交互强度的比值都会随着噪声量级的减小而逐渐减小。这说明随着训练进行(即逐渐变小),低阶交互强度与高阶交互强度的比值逐渐变小,神经网络逐渐学到更加高阶的交互。

最后,我们对比了在不同噪声量级下的理论交互值

在各个阶数上的分布

和实际训练过程中各阶交互的分布

,发现理论交互分布可以很好地预测实际训练中各时间点的交互强度分布。

图 10:比较理论交互分布

(蓝色直方图)和实际交互分布

(橙色直方图)。在训练第二阶段的不同时间点,理论交互分布都可以很好地预测和匹配实际交互的分布。更多结果请参见论文。天皓智联 开发板商城

3. 证明神经网络训练的第一阶段的交互变化动态过程。

如果说训练的第二阶段中交互的动态变化可以解释为权重的最优解在噪声逐渐减小时的变化,那么第一阶段就可认为是交互从初始化的随机交互逐渐收敛到最优解的过程。

路漫漫其修远兮,我们团队是做神经网络可解释性的第一性原理,我们希望在更多的方面把这个理论做扎实,能够严格证明等效交互是符号化的解释,并且能够解释神经网络的泛化性、鲁棒性,同时证明神经网络表征瓶颈,统一 12 种提升神经网络对抗迁移性的方法和解释 14 种重要性估计方法。我们后面会做出更扎实的工作,进一步完善理论体系

#Tora

阿里「轨迹可控版Sora」,告别「抽卡」,让视频生成更符合物理规律

你规定路线,Tora 来生成相应轨迹的视频。

目前,扩散模型能够生成多样化且高质量的图像或视频。此前,视频扩散模型采用 U-Net 架构 ,主要侧重于合成有限时长(通常约为两秒)的视频,并且分辨率和纵横比受到固定限制。

Sora 的出现打破了这一限制,其采用 Diffusion Transformer(DiT)架构,不仅擅长制作 10 到 60 秒的高质量视频,而且还因其生成不同分辨率、各种纵横比、且遵守实际物理定律的能力而脱颖而出。

可以说 Sora 是 DiT 架构最有利的证明,然而,基于 Transformer 的扩散模型在有效生成可控动作视频方面还未被充分探索。

针对这一问题,来自阿里的研究者提出了 Tora,这是第一个面向轨迹的 DiT 架构,它将文本、视觉和轨迹条件同时集成在一起以生成视频。

Tora 的设计与 DiT 的可扩展性无缝契合,允许精确控制具有不同持续时间、宽高比和分辨率的视频内容。大量实验证明,Tora 在实现高运动保真度方面表现出色,同时还能细致模拟物理世界的运动。

  • 论文地址:https://arxiv.org/pdf/2407.21705
  • 论文主页:https://ali-videoai.github.io/tora_video/
  • 论文标题:Tora: Trajectory-oriented Diffusion Transformer for Video Generation

一艘老式的木制帆船沿着规定好的路线在迷雾笼罩的河流上平稳地滑行,周围是茂密的绿色森林。

一条鲫鱼优雅地游过火星的红色岩石表面,鱼的轨迹向左,火星的轨迹向右。

热气球沿着不同的轨迹升入夜空,一个沿着规定的斜线,另一个沿着有弯度的轨迹。

两只可爱的小猫并排走在宁静的金色沙滩上。

气泡沿着轨迹轻轻地漂浮在盛开的野花中。

枫叶在清澈的湖面上颤动,映照着秋天的森林。

山间的瀑布倾泻而下,主题、背景的运动都可以按照不同的路线运动。

在 Tora 与其他方法的比较中,可以看出 Tora 生成的视频流畅度更高,更遵循轨迹,且物体不会存在变形的问题,保真度更好。

方法介绍

Tora 采用 OpenSora 作为其 DiT 架构的基础模型,包含一个轨迹提取器 (TE,Trajectory Extractor)、时空 DiT(Spatial-Temporal DiT )和一个运动引导融合器 (MGF,Motion-guidance Fuser) 。TE 使用 3D 视频压缩网络将任意轨迹编码为分层时空运动 patch。MGF 将运动 patch 集成到 DiT 块中,以生成遵循轨迹的一致视频。图 3 概述了 Tora 的工作流程。

时空 DiT(ST-DiT)

ST-DiT 架构包含两种不同的块类型:空间 DiT 块 (S-DiT-B) 和时间 DiT 块 (T-DiT-B),它们交替排列。S-DiT-B 包含两个注意力层,每个层按顺序执行空间自注意力 (SSA) 和交叉注意力,后面跟着一个逐点前馈层,用于连接相邻的 T-DiT-B 块。T-DiT-B 仅通过用时间自注意力 (TSA) 替换 SSA 来修改此架构,从而保持架构一致性。在每个块中,输入在经过规范化后,通过跳跃连接连接回块的输出。通过利用处理可变长度序列的能力,去噪 ST-DiT 可以处理可变持续时间的视频。

轨迹提取器

轨迹已被证明是一种更加用户友好的方法来控制生成视频的运动。然而,DiT 模型采用视频自编码器和 patch 化过程将视频转换为视频 patch。在这里,每个 patch 都是跨多个帧导出,因此直接采用帧间偏移是不合适的。为了解决这个问题,本文提出的 TE 将轨迹转换为运动 patch,运动 patch 与视频 patch 位于相同的潜在空间。

运动引导融合器

为了将基于 DiT 的视频生成与轨迹结合起来,本文探索了三种融合架构变体,将运动 patch 注入每个 ST-DiT 块。这些设计如图 4 所示。

实验结果

在实现细节上,研究者基于 OpenSora v1.2 权重来训练 Tora。训练视频的分辨率由 144p 到 720p 不等。为了平衡训练 FLOP 以及每次迭代不同分辨率和帧数所需的内存,研究者相应地将批大小从 1 调整到 25。

至于训练基础设施,研究者使用了 4 块英伟达 A100 和 Adam 优化器,学习率为 2 × 10^−5。

研究者将 Tora 与流行的运动指导视频生成方法进行了比较。评估中使用了三种设置,分别为 16、64 和 128 帧,所有设置都是 512×512 的分辨率。

结果如下表 1 所示,在 U-Net 方法常用的 16 帧设置下,MotionCtrl 和 DragNUWA 能够更好地与所提供的轨迹实现对齐,但仍弱于 Tora。随着帧数增加,U-Net 方法在某些帧中出现明显偏差,并且错位误差传播会导致后续序列中出现变形、运动模糊或物体消失。

相比之下,得益于集成了 Transformer 的缩放能力,Tora 对帧数变化表现出很高的稳健性。Tora 产生的运动更加流畅,且更符合物理世界。对于 128 帧测试设置下的评估,Tora 的轨迹精度达到其他方法的 3 到 5 倍,展现出了卓越的运动控制能力。

在下图 5 中,研究者对不同分辨率和持续时长的轨迹误差进行分析。结果显示,不同于 U-Net 随时间推移出现明显的轨迹误差,Tora 的轨迹误差随时间推移出现渐进增加。这与 DiT 模型中视频质量随时间增加而下降相一致。Tora 在更长的时间下保持了有效的轨迹控制。

下图 6 展示了 Tora 与主流运动控制方法的比较分析,在包含两人共同运动的场景中,所有方法都能生成相对准确的运动轨迹。不过,Tora 的视觉质量更好,这要归功于更长序列帧的使用,有助于实现更平滑的运动轨迹和更逼真的背景渲染。

可以看到,在 Tora 生成的自行车场景中,人的双腿表现出逼真的踩踏动作,而 DragNUWA 的双腿几乎水平漂浮,违反了物理真实性。此外,DragNUWA 和 MotionCtrl 在视频结尾处都出现了严重的运动模糊。

在另一个生成灯笼的场景中,DragNUWA 随着所提供轨迹的持续升降出现了严重的变形。MotionCtrl 的轨迹虽然相对准确,但生成的视频与两个灯笼的描述不相符。Tora 不仅严格地遵循了轨迹,而且最大程度地减少了物体变形,确保了更高保真度的动作表示。


#让循环语言模型超越Transformer++
小技巧大功效,「仅阅读两次提示」

在当前 AI 领域,大语言模型采用的主流架构是 Transformer。不过,随着 RWKV、Mamba 等架构的陆续问世,出现了一个很明显的趋势:在语言建模困惑度方面与 Transformer 较量的循环大语言模型正在快速进入人们的视线。

令人兴奋的是,这些架构在推理期间使用了恒定量的内存。不过,受制于有限的内存,循环语言模型(LM)无法记忆并使用长上下文中的所有信息,这导致了上下文学习(in-context learning,ICL)质量的不佳。因此,获得高效大语言模型的关键挑战在于选择存储或者丢弃哪些信息。

在最近的论文《Just read twice: closing the recall gap for recurrent language models》中,来自斯坦福大学、布法罗大学的研究者通过简单观察发现,数据在推理期间涌入循环语言模型的排序极大地影响了在有限内存中预测存储哪些信息的难度。

我们假设根据文档 D(比如伽利略・伽利莱的详细维基百科)来提问:伽利略是什么时候搬到的佛罗伦萨?这时,如果提示遵循了 [Q, D] 的排序,则模型只需要记住文档 D 中的一个事实即可。相反,如果提示遵循了 [D, Q] 的排序,则模型需要记住所有事实。如下图 1(左)所示。

因此,本文首先从理论上形式化了数据排序如何影响内存需求,然后提出两种方法来减轻对数据排序的依赖,分别是 Just-read-twice(JRT)提示策略和 JRT 循环架构。本文主要分为以下几个部分展开:

理解数据排序的作用。研究者得出的第一个洞见是:记忆问题的 hardness 要降低到与设置剥离(set disjointness,SD)相同,这是通信复杂度理论中持续数十年的最典型问题。SD 要求一种流算法(比如循环模型)来决定上下文中提供的输入集是否剥离:

理论分析和实验结果表明,第一个集 | A | 掌控了求解 SD 所需的内存。因果模型需要存储 A 中的所有元素以与 B 中的元素进行比较。这表明了,使用上下文中的「正确数据排序」(如将最小 min (|A|, |B|) 的集放在首位)将有助于内存受限的模型。更进一步,观察到上下文非因果逻辑的模型可在空间最小的 (|A|, |B|) 中求解 SD,而无需考虑数据排序。

其次是利用「正确的」排序。本文提出了一种非常简单的 JRT-Prompt 策略,在模型生成答案之前在上下文中将信息重复多次(如上图 1 右所示)。在第二以及更多轮次中,语言模型在决定存储哪些信息时要以完整的上下文为条件,从而有效避免了将数据排序「归正」的问题。

结果表明,JRT-Prompt 在 16 个已有循环语言模型和 6 项 ICL 任务上,实现了平均 11.0 ± 1.3 百分点的提升,而吞吐量是 FlashAttention-2(长度 32k、批大小 16)的 11.9 倍。JRT-Prompt 虽然增加了上下文长度,但渐进来看仍然比注意力更加地计算和内存高效。

超越因果模型。本文提出了 JRT-RNN,它的灵感来源于简单的 Prefix-LM 编码器解码器架构设计。大多数的上下文学习输入包含两部分内容,分别是输入的提示(上下文、指令)和作为输出的模型生成文本。在 Prefix-LM 架构中,LM 并没有遵循因果逻辑地处理提示区域,而对输出进行了因果解码,其中在因果区域仅使用了标准的下一个 token 预测损失,以及非因果区域上的损失。

不过遗憾的是,此前 Prefix-LM 模型的训练方法取得的成功有限,并使用了低效的 Transformer 主干。因此本文通过一些简单的改变来提高质量和效率,包括改进训练损失并使用称之为「Prefix Linear Attention,PLA」 的线性注意力公式。研究者发现,使用他们的 IO 感知实现,JRT-RNN 在 360m 和 1.3b 参数设置下,分别可以提供 13.7 和 6.9 百分点的平均质量改进,吞吐量是 FA2 的 19.2 倍。

  • 论文地址:https://arxiv.org/pdf/2407.05483
  • 项目主页:https://github.com/HazyResearch/prefix-linear-attention

JRT-Prompt 方法概览

上下文学习任务以 (C, Q, Y) 作为输入,其中 C 为一些上下文来源(如文档或代码存储库),Q 为给定上下文时对模型的一些问题或请求,Y 为答案。对于使用自回归 LM A 的标准上下文学习,研究者输入 C 和 Q,并根据正确的完成情况 Y 来评估生成的输出 Yˆ = A (C, Q)。

JRT-Prompt 是一种极其简单的方法,在提示模型输出答案之前会在上下文中重复提示中的信息(如问题和文档),例如下图 1 右的 Yˆ = A (C, Q, C, Q)。因此,在上下文第二次出现时,模型根据完整的上下文来决定存储哪些信息。

此外,JRT-Prompt 可以与现成的 LLM 一起使用。研究者在零样本提示下,在一系列记忆密集型上下文任务上评估了以下 LM:

  • Based 预训练 LM,参数规模为 1.3B,在 Pile 的 10 − 50B 个 token 上进行训练;
  • Mamba 预训练的 LM,参数规模为 130M、370M、1.4B 和 2.8B,在 Pile 的 300B 个 token 上进行训练;
  • Gated Linear Attention 预训练的 LM,参数规模为 1.3B 和 2.7B,在 SlimPajama 数据集的 100B 个 token 上进行训练;
  • Mamba-2 预训练的 LM,参数规模为 130M、370M、1.3B 和 2.7B,在 Pile 的 300B 个 token 上进行训练。

结果如下表 1 所示,通过增加状态(state)大小,研究者发现 JRT-Prompt 方法在各个模型和任务上平均带来了 11.0 ± 1.3 百分点的性能提升,利用该方法的 Based 模型平均优于利用标准提示的 Transformer 模型。

他们还发现,JRT-Prompt 可以使 Transformer 模型受益,并且该方法在一些任务上(附录 2)比少样本学习更加有效。值得注意的是,Springer 等人在论文《Repetition improves language model embeddings》中提出使用自回归 Transformer 模型来重复上下文以实现生成嵌入的目的,本文的研究结果也类似。研究者专注于亚二次架构和上下文学习任务。

JRT-Prompt 虽然由于重复而增加了上下文长度,但是其使用的亚二次循环架构仍比使用二次 Transformer 模型更高效。研究者发现,在序列长度 N = 32768、批大小为 16 时,使用 JRT-Prompt(序列长度 2N)在英伟达 H100 上提供的吞吐量是 FlashAttention-2(序列长度 N)的 11.9 倍。

JRT-RNN:编码器 - 解码器循环架构

JRT-RNN 的灵感来自于 Prefix-LMs,但侧重于扩展质量 - 效率权衡空间的帕累托边界(Pareto frontier)。为了提高质量,JRT-RNN 在编码器端使用了单独的 k_e 和 v_e 映射,在解码器端使用了 k_d 和 v_d 映射。虽然 Prefix LM 模型对编码器和解码器区域使用了共享映射权重,但研究者发现使用两组映射可以提高质量。

为了提高效率,JRT-RNN 为编码器使用了非因果线性注意力,而为解码器使用标准因果线性注意力。研究者称为 Prefix Linear Attention(PLA)(图 1 右),公式如下:

JRT-RNN 训练目标。Prefix LMs 通常不计算非因果区域的损失,而 JRT-RNN 将下一个 token 预测与掩码语言建模(MLM)目标进行了结合。并且对于添加的 MLM 目标,研究者用一个 [MASK] token 替换了来自编码器区域 {u_1, ..., u_M} 的比例为 P 的 tokens,并在预测原始 token 时测量了交叉熵损失。

损失如下:

实验结果

在实验中,研究者评估了 JRT-RNN 在以下三个指标上的质量和效率:

  • 上下文学习质量
  • 整体语言建模
  • 生成

上下文学习质量

如下表 2 所示,研究者发现,JRT-RNN 在参数为 360M(30B tokens)时比仅解码器的基线(Based)平均高出 13.7 个百分点,在参数为 1.3B(50B tokens)时平均高出 6.9 个百分点。

同时,JRT-RNN 在参数为 360M 和 1.3B 时与 Transformer++ 的差距分别缩小到了 0.5 个百分点和 1.9 个百分点之内。

在下表 3 中,研究者比较了当 prefill 长度 l 小于编码器长度 M 时,JRT-RNN 与同类推理策略的表现。

整体自然语言理解

根据以往研究,研究者进一步将困惑度分为了两组:联想记忆「AR slice」包括了被称为「AR hits」的 tokens,它们需要模型按照顺序执行记忆以正确地预测下一个 token;而「Other slice」包含剩余的 tokens(如记忆的知识)。

对于记忆频率,JRT-RNN 在「AR slice」表现出色。对于训练期间不常见的二元组(即不太可能在模型参数中被记住的),JRT-RNN 的困惑度相对于 Based 和 Mamba 这两个强大的因果循环基线有所改善。

对于记忆距离,在「AR slice」中,JRT-RNN 与仅解码器基线之间的差距随着上下文中重复二元组的增加而扩大。这也进一步证明了 JRT-RNN 可以帮助完成更长的上下文记忆任务。

非记忆频率。对于训练期间很少见到的二元组的非记忆「Other slice」,JRT-RNN 的困惑度比仅解码器的 LM 更差。这是意料之中的结果,因为 JRT-RNN 计算了仅解码器 LM 的 65% tokens 的损失。

我们预计这一差距会随着规模和训练时间的延长而缩小(随着二元语法频率的增加而增加)(图 3,左上角)。

生成吞吐量

生成可以分解为提示「prefill 处理」和解码「下一个 token 预测」两步。相较于标准的仅解码器循环模型,JRT-RNN 不会修改解码步骤,因此讨论重点在 prefill 阶段。

使用 Simran Arora 等人论文《Simple linear attention language models balance the recall-throughput tradeof》中提出的 Based CUDAn 内核,JRT-Prompt 在处理 prefill 时吞吐量分别是 FlashAttention-2 和 FLA Triton 内核的 11.9 和 13.7 倍,如下表 5 所示。

当研究者将批大小增加到 64 时,JRT-Prompt 吞吐量分别是 FlashAttention-2 和 FLA Triton 内核的 6.1 倍和 7.2 倍。

接下来他们扩展了 Based 内核以支持 JRT-RNN,并且证明了当将序列长度增加到 32768 时,吞吐量分别是 FlashAttention-2 和 FLA 的 19.2 倍和 22.0 倍。当将批大小增加到 64 时,JRT-RNN 分别又提供了 9.7 倍和 11.5 倍的吞吐量提升。JRT-RNN 所需的时间是 Based prefill 的 1.24 倍,比 JRT-Prompt 更加高效。