导语:本文主要介绍了如何理解 PyTorch 中的爱因斯坦求和 (einsum) ,并结合实际例子讲解和 PyTorch C++实现代码解读,希望读者看完本文后掌握 einsum 的基本用法。
撰文|梁德澎
原文首发于公众号GiantpandaCV
1
爱因斯坦求和约定
爱因斯坦求和约定(einsum)提供了一套既简洁又优雅的规则,可实现包括但不限于:向量内积,向量外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练运用 einsum 可以很方便地实现复杂的张量操作,而且不容易出错。
三条基本规则
首先看下 einsum 实现矩阵乘法的例子:
a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)
其中需要重点关注的是 einsum 的第一个参数 "ik,kj->ij",该字符串(下文以 equation 表示)表示了输入和输出张量的维度。equation 中的箭头左边表示输入张量,以逗号分割每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是26个英文字母 'a' - 'z'。
而 einsum 的第二个参数表示实际的输入张量列表,其数量要与 equation 中的输入数量对应。同时对应每个张量的 子 equation 的字符个数要与张量的真实维度对应,比如 "ik,kj->ij" 表示输入和输出张量都是两维的。
equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k] 和 b[k, j] 沿着 k 这个维度做内积得到的。
接着介绍两个基本概念,自由索引(Free indices)和求和索引(Summation indices):
-
自由索引,出现在箭头右边的索引,比如上面的例子就是 i 和 j;
-
求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k。
接着是介绍三条基本规则:
-
规则一:equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
-
规则二:只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
-
规则三:equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。
特殊规则
特殊规则有两条:
-
equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
-
equation 中支持 "..." 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:
a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
b = torch.einsum('...ij->...ji', [a])
2
实际例子解读
接下来将展示13个具体的例子,在这些例子中会将 PyTorch einsum 与对应的 PyTorch 张量接口和 Python 简单的循环展开实现做对比,希望读者看完这些例子之后能轻松掌握 einsum 的基本用法。
实验代码github链接:
https://github.com/Ldpe2G/CodingForFun/tree/master/einsum_ex
1.提取矩阵对角线元素
import torch
import numpy as np
a = torch.arange(9).reshape(3, 3)
# i = 3
torch_ein_out = torch.einsum('ii->i', [a]).numpy()
torch_org_out = torch.diagonal(a, 0).numpy()
np_a = a.numpy()
# 循环展开实现
np_out = np.empty((3,), dtype=np.int32)
# 自由索引外循环
for i in range(0, 3):
# 求和索引内循环
# 这个例子并没有求和索引,
# 所以相当于是1
sum_result = 0
for inner in range(0, 1):
sum_result += np_a[i, i]
np_out[i] = sum_result
print("input:\n", np_a)
print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
# 终端打印结果
# input:
# [[0 1 2]
# [3 4 5]
# [6 7 8]]
# torch ein out:
# [0 4 8]
# torch org out:
# [0 4 8]
# numpy out:
# [0 4 8]
# is np_out == torch_ein_out ? True
# is torch_org_out == torch_ein_out ? True
2. 矩阵转置
import torch
import numpy as np
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
torch_ein_out = torch.einsum('ij->ji', [a]).numpy()
torch_org_out = torch.transpose(a, 0, 1).numpy()
np_a = a.numpy()
# 循环展开实现
np_out = np.empty((3, 2), dtype=np.int32)
# 自由索引外循环
for j in range(0, 3):
for i in range(0, 2):
# 求和索引内循环
# 这个例子并没有求和索引
# 所以相当于是1
sum_result = 0
for inner in range(0, 1):
sum_result += np_a[i, j]
np_out[j, i] = sum_result
print("input:\n", np_a)
print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
print("is np_out == torch_org_out ?", np.allclose(torch_ein_out, np_out))
print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out))
# 终端打印结果
# input:
# [[0 1 2]
# [3 4 5]]
# torch ein out:
# [[0 3]
# [1 4]
# [2 5]]
# torch org out:
# [[0 3]
# [1 4]
# [2 5]]
# numpy out:
# [