HOPE: High-Order Polynomial Expansion of Black-Box Neural Networks (TPAMI 2024)
Paper https://ieeexplore.ieee.org/abstract/document/10528900
Project https://github.com/HarryPotterXTX/HOPE.git
1、引言
神经网络由于其强大的性能而被应用于各个领域,但其黑盒性质阻碍了其在医学、决策、工业等的发展。文章提出当一个神经网络的各个组件均为高阶可导时,它可以被展开为一个Taylor多项式,对其进行局部或全局解释。如果前向复合求神经网络的Taylor多项式,其计算量是非常大的。Taylor多项式所需要的也就只有在某点的各阶导数,因此只要我们能够求得神经网络精确的的高阶导数即可得到其Taylor展开式,将黑盒神经网络转换为一个显式的表达式。
2、复合函数的高阶求导法则
神经网络可以视为一个多层的复合函数
y
=
f
(
d
)
∘
f
(
d
−
1
)
∘
…
∘
f
(
1
)
(
x
)
\bold y=f^{(d)}\circ f^{(d-1)}\circ\ldots\circ f^{(1)}(\bold x)
y=f(d)∘f(d−1)∘…∘f(1)(x),其中
f
(
m
)
f^{(m)}
f(m)是神经网络的第m个模块。见下图,文章首先分析了复合函数的一般性高阶求导法则,然后将其应用于神经网络中,类似于梯度的反向传播,从最后一层到中间层,最后到输入层,逐步求得神经网络输出关于输入的各阶导数。
考虑一个复合函数
y
=
f
2
∘
f
1
(
x
)
\bold{y}=f_2\circ f_1(\bold{x})
y=f2∘f1(x), 其中
z
=
f
1
(
x
)
∈
R
s
\bold{z}=f_1(\bold{x})\in \mathbb{R}^s
z=f1(x)∈Rs为中间状态变量,
y
=
f
2
(
z
)
∈
R
o
\bold{y}=f_2(\bold{z})\in \mathbb{R}^o
y=f2(z)∈Ro 为输出变量。根据输入、中间和输出变量的维度,由简到难将其分为单输入-单状态-单输出(SISSSO),多输入-多状态-单输出(MIMSSO),多输入-多状态-多输出(MIMSMO)三种系统。先从最简单的SISSSO开始推导,由浅入深,方便理解更复杂更通用的其他情况。
2.1 SISSSO
首先,根据链式法则,我们可以得到
∂
k
y
∂
x
k
\frac{\partial^k\bold y}{\partial\bold x^k}
∂xk∂ky,
∂
k
y
∂
z
k
\frac{\partial^k\bold y}{\partial\bold z^k}
∂zk∂ky,
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz之间的关系。注意,如果碰到了
∂
k
y
∂
z
k
−
1
∂
x
\frac{\partial^k\bold y}{\partial\bold z^{k-1}\partial\bold x}
∂zk−1∂x∂ky,将其转换为
∂
z
∂
x
∂
k
y
∂
z
k
\frac{\partial\bold z}{\partial\bold x}\frac{\partial^k\bold y}{\partial\bold z^{k}}
∂x∂z∂zk∂ky的形式,我们需要将这种混合偏导数转换为非混合偏导数,方便解耦。
根据前三阶的导数推导,可以看到要想求
∂
k
y
∂
x
k
\frac{\partial^k\bold y}{\partial\bold x^k}
∂xk∂ky,只需要知道
∂
k
y
∂
z
k
\frac{\partial^k\bold y}{\partial\bold z^k}
∂zk∂ky,
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz。为了方便,我们将其重写为矩阵形式
简写为
在这个方程中
v
y
,
x
,
v
y
,
z
∈
R
n
\bold v^{y,x}, \bold v^{y,z} \in \mathbb{R}^n
vy,x,vy,z∈Rn 包含了
∂
k
y
∂
x
k
\frac{\partial^k\bold y}{\partial\bold x^k}
∂xk∂ky,
∂
k
y
∂
z
k
\frac{\partial^k\bold y}{\partial\bold z^k}
∂zk∂ky (k=1,…,n),
M
z
,
x
∈
R
n
×
n
\bold M^{z,x} \in \mathbb{R}^{n\times n}
Mz,x∈Rn×n 是一个由
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz所组成的转换矩阵,它是一个下三角矩阵。如果我们知道
f
1
f_1
f1和
f
2
f_2
f2的具体形式,
∂
k
y
∂
z
k
\frac{\partial^k\bold y}{\partial\bold z^k}
∂zk∂ky和
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz都是比较好求的,唯一的问题就是这个下三角矩阵该怎么求,只要我们知道下三角矩阵的每一个位置是
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz的什么组合,
∂
k
y
∂
x
k
\frac{\partial^k\bold y}{\partial\bold x^k}
∂xk∂ky的求解就变得非常简单。下面我们推导
M
z
,
x
\bold M^{z,x}
Mz,x的递推公式。
方程的第i项和第(i+1)项可以重写为
对公式5两边求导可以得到(i+1)项的另一种求法
进一步,由于是下三角矩阵,
M
i
,
0
z
,
x
=
0
\bold M^{z,x}_{i,0}=0
Mi,0z,x=0,
M
i
,
n
z
,
x
=
0
(
i
<
n
)
\bold M^{z,x}_{i,n}=0~(i < n)
Mi,nz,x=0 (i<n),方程7简化为
对比公式6和8,可以得到递推公式为
以上,首先确定需要计算几阶导数,通过公式9得到转换矩阵的具体形式,将
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz代入该矩阵,然后通过公式3中简单的矩阵乘法便可直接计算得到
∂
k
y
∂
x
k
\frac{\partial^k\bold y}{\partial\bold x^k}
∂xk∂ky。并且矩阵M的形式是通用的,只是代入的值不同,对于神经网络不同层,代入不同的
∂
k
z
∂
x
k
\frac{\partial^k\bold z}{\partial\bold x^k}
∂xk∂kz即可得到相应层的转换矩阵,计算消耗很小。
关于转换矩阵怎么求,补充材料中有简单的小例子帮助理解
2.2 MIMSSO
2.2.1 求解非混合偏微分
相较于SISSSO,MIMSSO的输入和中间状态变量的个数增加了,和上面的推导过程类似,前三阶导数为
为了之后推导方便,定义了新的算子,用来存储各阶导数
方程10就可以重写为以下形式
同样,我们将其转换为矩阵表示形式,得到
可以看到它和方程3的形式是一模一样的,只是乘法和次方变成了Hadamard积、Hadamard幂了。所以我们依旧可以使用方程9的递推公式,先得到矩阵的形式,然后代入
β
k
z
T
β
x
k
\frac{\beta^k\bold z^T}{\beta\bold x^k}
βxkβkzT即可,需要代入得只有各阶非混合偏微分,求解还是很简单的。
2.2.2 求解混合偏微分
神经网络的第一个模块很多都是线性层,比如全连接层的线性加权或者卷积层,只有一阶导数非零,其高阶偏导数均为0
前三阶混合偏导数为
可以简写为
其中
重写为矩阵形式,有
对于神经网络的其他层,我们都只需要根据2.1.1中的公式计算出其各阶非混合偏导数即可,只有反向传播到第一层时才需要执行以上操作计算混合偏导数。
2.3 MIMSMO
这是最宽泛的情况,输入、状态、输出变量个数不限。我们只需要对每一个输出均执行2.2节的公式即可。
3、神经网络的高阶求导法则
类似于反向传播,这里通过从输出层到中间层,最后到输入层不断迭代使用第2章得到的公式,计算得到每一层的高阶导数,计算消耗主要来源于每一层转换矩阵的计算。中间层只需要计算各阶非混合偏导数,仅在输入层计算混合偏导数。
3.1 输出单元
首先得有一个初始化的
v
7
\bold v_7
v7,由于
y
=
y
(
7
)
\bold y=\bold y^{(7)}
y=y(7),只有一阶导为1,高阶导为0,得到
3.2 全连接层
全连接层的输入输出关系为
这里分析
y
(
m
)
→
y
(
m
+
1
)
→
y
\bold y^{(m)}\to\bold y^{(m+1)}\to \bold y
y(m)→y(m+1)→y这一段,
y
(
m
)
\bold y^{(m)}
y(m)为输入,
y
(
m
+
1
)
\bold y^{(m+1)}
y(m+1)为状态变量,有
可以得到全连接层的转换矩阵为一个对角矩阵
其混合偏导转换矩阵为
3.3 卷积层
卷积层也可以视为稀疏权重的全连接层,但是如果将其转换为等价的全连接层然后再计算的话太费时了。这里在卷积层面重新推导了卷积层高阶求导公式。卷积层输入输出关系为
y
(
m
+
1
)
\bold y^{(m+1)}
y(m+1)的第u个输出是卷积核
F
(
m
+
1
)
\bold F^{(m+1)}
F(m+1)全部元素和
y
(
m
)
\bold y^{(m)}
y(m)部分元素的乘积,
其一阶和高阶导数为
如果我们想得到卷积的一阶导数,一般使用反卷积形式(补充材料有说明),公式34化为如下形式
对比公式34和35,自然也可以得到其高阶导数的反卷积计算公式,如下
3.4 非线性激活函数和池化层
非线性激活函数和池化层的推导大部分在补充材料。需要计算激活函数的高阶导数,主要讲讲Sigmoid、Tanh和GELU,和之前三角矩阵递推公式推导思想类似,将复杂问题转换为矩阵的求解问题。
3.4.1 Sigmoid高阶导数公式
3.4.2 Tanh高阶导数公式
3.4.3 GELU高阶导数公式
4、神经网络的高阶Taylor展开
只要根据上面的公式求得各阶导数,Taylor多项式自然而然就得到了。
这一章分析了一下上下界、收敛性、时间复杂度和全局可解释性。其中全局可解释章节计算不同参考点下的Taylor多项式,然后需要将所有多项式都转换到同一个参考点下比较各个项的系数,来得到一个全局解释,难点在于如何将不同参考点的多项式转换为同一参考点。这里将变量通过编码的形式解决。首先方程41可以重写为
我们对这个Taylor多项式再一次求它在某点的各阶导数,得到在另一个点上的等价的Taylor多项式,进而将其转换到另一个参考点上。
比如神经网络在x=1点展开得到的Taylor多项式为
y
=
(
x
−
1
)
+
2
(
x
−
1
)
2
y=(x-1)+2(x-1)^2
y=(x−1)+2(x−1)2,其在参考点x=2的输出及前两阶导数为
y
∣
x
=
2
=
3
y|_{x=2}=3
y∣x=2=3,
∂
y
∂
x
∣
x
=
2
=
5
\frac{\partial y}{\partial x}|_{x=2}=5
∂x∂y∣x=2=5,
∂
2
y
∂
x
2
∣
x
=
2
=
4
\frac{\partial^2 y}{\partial x^2}|_{x=2}=4
∂x2∂2y∣x=2=4,因此其在x=2点的等价Taylor多项式为
y
=
3
+
5
(
x
−
2
)
+
2
(
x
−
2
)
2
y=3+5(x-2)+2(x-2)^2
y=3+5(x−2)+2(x−2)2。可以验算,二者均为
y
=
1
−
3
x
+
2
x
2
y=1-3x+2x^2
y=1−3x+2x2。
公式51z中Taylor多项式的各阶导数为
其中
β
n
\beta_n
βn为已知系数。为了简化
∂
k
Δ
n
∂
x
i
1
…
∂
x
i
k
\frac{\partial^k \Delta_n}{\partial\bold x_{i_1}\ldots\partial\bold x_{i_k}}
∂xi1…∂xik∂kΔn的计算,在执行偏导操作之前首先对
Δ
n
\Delta_n
Δn 进行编码。具体来说,
a
0
Δ
x
1
a
1
Δ
x
2
a
2
…
Δ
x
p
a
p
a_0\Delta\bold x_1^{a_1}\Delta\bold x_2^{a_2}\ldots \Delta\bold x_p^{a_p}
a0Δx1a1Δx2a2…Δxpap编码为
[
a
0
,
a
1
,
…
,
a
p
]
[a_0,a_1,\ldots,a_p]
[a0,a1,…,ap],则其关于
x
i
\bold x_i
xi的偏导数
(
a
0
a
i
)
Δ
x
1
a
1
…
Δ
x
i
a
i
−
1
…
Δ
x
p
a
p
(a_0a_i)\Delta\bold x_1^{a_1}\ldots\Delta\bold x_i^{a_i-1}\ldots \Delta\bold x_p^{a_p}
(a0ai)Δx1a1…Δxiai−1…Δxpap的编码为
[
a
0
a
i
,
a
1
,
…
,
a
i
−
1
,
…
,
a
p
]
[a_0a_i,a_1,\ldots,a_i-1,\ldots,a_p]
[a0ai,a1,…,ai−1,…,ap]。比如,假设输入个数为2,要求2阶导数,有
如要求
Δ
n
\Delta_n
Δn关于
x
1
\bold x_1
x1的偏导数。只需要将第一列编码(系数)乘以第二列编码(
x
1
\bold x_1
x1的指数),第二列编码减1即可。
要求
∂
2
Δ
n
∂
x
1
∂
x
2
\frac{\partial^2 \Delta_n}{\partial\bold x_1\partial\bold x_2}
∂x1∂x2∂2Δn,在公式54编码基础上第一列乘以第三列(
x
2
\bold x_2
x2的指数),第三列减1即可。
将不同点得到的Taylor多项式均转换到同一个参考点下,对比系数即可得到一个全局解释。
5、实验
首先对方法本身性能进行了验证,在近似精度、运行效率、收敛性、上下界等方面进行了说明。对比的主要是Pytorch的Autograd自动求导模块。github上两种方式的Taylor展开都写了,可以玩玩。
应用上写了三个:函数发现、快速推理、特征提取
5.1、函数发现
使用INR对系统进行重现,然后使用HOPE对该网络进行展开,得到网络的显式表达式。
2D系统原函数如下:
在(0.0,0.0), (0.5,0.5), and (-0.5,-0.5)三个点得到的Taylor多项式如下,与原函数一致。
在900多个不同点上进行采样展开,并绘制各个系数的箱型图得到
更复杂一点的包含了一个潜在变量的5D函数:
不论是潜在变量不可观测,潜在变量部分可观测,或潜在变量完全可观测,得到的结论都与理论一致。
5.2、快速推理
设计了一个单水箱神经网络控制器,然后用HOPE对其展开并完全代替,其推理时间是能够明显提升的。
5.3、特征提取
训练了一个MNIST手写数字分类器,将其转化为10个等价的单输出模型,在某张参考图像上对其进行Taylor热图分析,HOPE得到的结果与基于扰动的基线一致,一阶热图因为忽略了高阶项所以并不相同,LRP加入一些人为因素,因此和基于扰动的很大不同。选基于扰动的方法作为基线是因为这是模型真实的输出变化,没有加入任何主观因素。在时间上,基于扰动的方法最久,其他方法相差不大。