一文轻松掌握深度学习框架中的einsum

60f2071ef8dbc616ffd6859b4bd04270.png

导语:本文主要介绍了如何理解 PyTorch 中的爱因斯坦求和 (einsum) ,并结合实际例子讲解和 PyTorch C++实现代码解读,希望读者看完本文后掌握 einsum 的基本用法。

撰文|梁德澎

原文首发于公众号GiantpandaCV

 

1

爱因斯坦求和约定

爱因斯坦求和约定(einsum)提供了一套既简洁又优雅的规则,可实现包括但不限于:向量内积,向量外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练运用 einsum 可以很方便地实现复杂的张量操作,而且不容易出错。

三条基本规则

首先看下 einsum 实现矩阵乘法的例子:

a = torch.rand(2,3)
b = torch.rand(3,4)
c = torch.einsum("ik,kj->ij", [a, b])
# 等价操作 torch.mm(a, b)

其中需要重点关注的是 einsum 的第一个参数 "ik,kj->ij",该字符串(下文以 equation 表示)表示了输入和输出张量的维度。equation 中的箭头左边表示输入张量,以逗号分割每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是26个英文字母 'a' - 'z'。

而 einsum 的第二个参数表示实际的输入张量列表,其数量要与 equation 中的输入数量对应。同时对应每个张量的 子 equation 的字符个数要与张量的真实维度对应,比如 "ik,kj->ij" 表示输入和输出张量都是两维的。

equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k] 和 b[k, j] 沿着 k 这个维度做内积得到的。

接着介绍两个基本概念,自由索引(Free indices)和求和索引(Summation indices):

  • 自由索引,出现在箭头右边的索引,比如上面的例子就是 i 和 j;

  • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k。

接着是介绍三条基本规则:

  • 规则一:equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;

  • 规则二:只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;

  • 规则三:equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

特殊规则

特殊规则有两条:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;

  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:

a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
b = torch.einsum('...ij->...ji', [a])

2

实际例子解读

接下来将展示13个具体的例子,在这些例子中会将 PyTorch einsum 与对应的 PyTorch 张量接口和 Python 简单的循环展开实现做对比,希望读者看完这些例子之后能轻松掌握 einsum 的基本用法。

实验代码github链接:

https://github.com/Ldpe2G/CodingForFun/tree/master/einsum_ex

1.提取矩阵对角线元素

import torch
import numpy as np

a = torch.arange(9).reshape(3, 3)
# i = 3
torch_ein_out = torch.einsum('ii->i', [a]).numpy()
torch_org_out = torch.diagonal(a, 0).numpy()

np_a = a.numpy()
# 循环展开实现
np_out = np.empty((3,), dtype=np.int32)
# 自由索引外循环
for i in range(0, 3):
    # 求和索引内循环
    # 这个例子并没有求和索引,
    # 所以相当于是1
    sum_result = 0
    for inner in range(0, 1):
        sum_result += np_a[i, i]
    np_out[i] = sum_result

print("input:\n", np_a)
print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))

# 终端打印结果
# input:
#  [[0 1 2]
#  [3 4 5]
#  [6 7 8]]
# torch ein out:
#  [0 4 8]
# torch org out:
#  [0 4 8]
# numpy out:
#  [0 4 8]
# is np_out == torch_ein_out ? True
# is torch_org_out == torch_ein_out ? True

2. 矩阵转置

import torch
import numpy as np

a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
torch_ein_out = torch.einsum('ij->ji', [a]).numpy()
torch_org_out = torch.transpose(a, 0, 1).numpy()

np_a = a.numpy()
# 循环展开实现
np_out = np.empty((3, 2), dtype=np.int32)
# 自由索引外循环
for j in range(0, 3):
    for i in range(0, 2):
        # 求和索引内循环
        # 这个例子并没有求和索引
        # 所以相当于是1
        sum_result = 0
        for inner in range(0, 1):
            sum_result += np_a[i, j]
        np_out[j, i] = sum_result

print("input:\n", np_a)
print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
print("is np_out == torch_org_out ?", np.allclose(torch_ein_out, np_out))
print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out))

# 终端打印结果
# input:
#  [[0 1 2]
#  [3 4 5]]
# torch ein out:
#  [[0 3]
#  [1 4]
#  [2 5]]
# torch org out:
#  [[0 3]
#  [1 4]
#  [2 5]]
# numpy out:
#  [
  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值