引言: torch.einsum()的分析和介绍已经有很多博客介绍过了, 但大多数的落脚点都是爱因斯坦求和约定,许多篇幅是用于介绍爱因斯坦求和约定到的各项法则,而实际案例分析方面只是草草给出一笔带过,涉及到的案例也较为简单。而实际我们要用到或者看到torch.einsum()的时候往往是在计算非常复杂的情况下。
因此本文将从实际复杂案例的角度对torch.einsum()的计算过程进行分析,一步一步的推导最终输出的每个元素和输入元素之间的关系。
爱因斯坦求和约定
首先,torch.einsum()
的基础原理是爱因斯坦求和约定,此处为了行文的整体性将对其进行简要的介绍,如果只关注计算本身,可以跳到下一节。爱因斯坦求和约定是为了简化计算而诞生的一种“记法”,就类似于我们用
×
\times
×来标记乘法一样,不同之处在于爱因斯坦求和约定可表示的运算更为复杂、灵活性也更高。爱因斯坦求和约定的典型写法为:
i
1
i
2
.
.
.
i
N
,
j
1
j
2
.
.
.
j
M
→
i
k
1
i
k
2
.
.
j
l
1
j
l
1
,
k
1
.
.
.
∈
N
,
l
1
.
.
l
∈
M
i_1i_2...i_N,j_1j_2...j_M\rightarrow i_{k_1}i_{k_2}..j_{l_1}j_{l_1},k_1...\in N,l_1..l\in M
i1i2...iN,j1j2...jM→ik1ik2..jl1jl1,k1...∈N,l1..l∈M
其中左端
i
1
i
2
.
.
.
i
N
,
j
1
j
2
.
.
.
j
M
i_1i_2...i_N,j_1j_2...j_M
i1i2...iN,j1j2...jM就表示了输入两个矩阵元素的坐标索引,右端
i
k
1
i
k
2
.
.
j
l
1
j
l
1
i_{k_1}i_{k_2}..j_{l_1}j_{l_1}
ik1ik2..jl1jl1为输出矩阵元素的坐标索引,可以看到输出矩阵元素索引相较于输入端的索引可能会缺少几项,运算就是发生这几个维度上的乘累加操作。
其中同时出现在左端和右端的坐标索引为自由索引,只用于标记位置;而仅仅出现在右端的索引为求和索引,爱因斯坦求和约定的本质就是沿着求和索引的方向计算两个输入逐元素乘累加和的结果放到输出自由索引的位置上,更为细致的介绍参见:一文学会 Pytorch 中的 einsum和einsum:爱因斯坦求和约定
举例而言:
i
j
,
j
k
→
i
k
ij,jk\rightarrow ik
ij,jk→ik
就表示沿着
j
j
j这个维度进行乘累加操作:
O
i
k
=
∑
j
A
i
j
B
j
k
O_{ik}=\sum_{j}A_{ij}B_{jk}
Oik=j∑AijBjk
输出的第
(
i
,
k
)
(i,k)
(i,k)个元素为
A
i
⋅
A_{i \cdot}
Ai⋅的行向量和
B
⋅
k
B_{\cdot k}
B⋅k列向量逐元素乘累加,实际上就是矩阵相乘。
复杂案例推导
正如第一节中所介绍的,torch.enisum()的核心计算过程就是沿着只在算式右边出现的轴对输入矩阵元素进行乘累加得到对应位置的输出元素。因此,想要弄清一个复杂的torch.eisum()表达式含义需要做的也只是将这个求和公式写出来再仔细分析。
案例. 四维张量乘三维张量
给出一个复杂案例:
n
c
j
t
,
n
p
j
−
>
n
c
p
t
ncjt,npj->ncpt
ncjt,npj−>ncpt
则其输出元素可以表示为:
C
n
c
p
t
=
∑
j
A
n
c
j
t
B
n
p
j
C_{ncpt}=\sum_j A_{ncjt}B_{npj}
Cncpt=j∑AncjtBnpj
首先我们可以注意到对于C的第一维
n
n
n而言,它同时出现在A和B的首位,也就是对于这一维的每个元素,都是会逐元素的执行A和B剩余维度的计算再在当前维度上排布,用深度学习中的描述来说就是对BATCH中的每个元素都独立的执行后续子操作,子操作可以记为:
C
c
p
t
=
∑
j
A
c
j
t
B
p
j
C_{cpt}=\sum_j A_{cjt}B_{pj}
Ccpt=j∑AcjtBpj
紧接着,对当前算式的第一维
c
c
c来说它只出现在
A
A
A中,每沿着
c
c
c计算一个不同的元素都要和“相同”的B计算,也就出现了广播机制,B有了个隐藏的、元素重复的维度
c
c
c,计算变为
C
c
p
t
=
∑
j
A
c
j
t
B
c
p
j
C_{cpt}=\sum_j A_{cjt}B_{cpj}
Ccpt=∑jAcjtBcpj,同第一步计算的原理,这里又可以化简成逐元素的子操作:
C
p
t
=
∑
j
A
j
t
B
p
j
=
∑
j
B
p
j
A
j
t
C_{pt}=\sum_jA_{jt}B_{pj}=\sum_jB_{pj}A_{jt}
Cpt=j∑AjtBpj=j∑BpjAjt
此时易看出
(
p
,
t
)
(p,t)
(p,t)元素就是B的第
p
p
p行向量和A的第
t
t
t列向量求内积。
从而我们可以得出结论,这一表达式的意思是,对于BATCH内的每个元素(A‘三维,B’二维),对B在第一维度进行广播(A’‘三维,B’‘三维),最后沿着第二维和第三维计算矩阵相乘B’‘‘A’’‘(A’‘‘二维,B’’'二维)。
而整个的推导过程可以总结为以下几要点:
- 沿着维数较高输入的第一维开始,判断是否存在于B中,如果在的话就可认为是逐元素操作,暂时忽略该维度;
- 如果该维度指示不在B中,则进行广播操作,重新回到1,否则3.
- 判断当前最简表达式的意义。