einsum初探

Einsum 是干嘛的?

使用爱因斯坦求和约定,可以以简单的方式表示许多常见的多维线性代数数组运算。举个栗子:给定两个矩阵A和B,我们想对它们做一些操作,比如 multiply、sum或者transpose。虽然numpy里面有可以直接使用的接口,能够实现这些功能,但是使用enisum可以做的更快、更节省空间。比如:

A = np.array([0, 1, 2])
B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

我们想计算A和B的element-wise乘积,然后按行求和。如果不使用einsum接口,需要先对A做reshape到和B一样的形状,创建一个临时的数组A[:, np.newaxis],然后在做乘积并按行求和:

(A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

使用einsum接口,不需要创建reshape后的临时数组,只是简单地对行中的乘积求和,这样会加速三倍:

np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

如何使用 einsum

使用einsum的关键是,正确地labelling(标记)输入数组和输出数组的axes(轴)。我们可以使用字符串(比如:ijk,这种表示方式更常用)或者一个整数列表(比如:[0,1])来标记axes。 再来举个栗子:为了实现矩阵乘,我们可以这么写

np.einsum('ij,jk->ik', A, B)

字符串'ij,jk->ik'可以根据'->'的位置来切分,左边的部分('ij,jk')标记了输入的axes,右边的('ik')标记了输出的axes。输入标记又根据','的位置进行切分,'ij'标记了第一个输入A的axes,'jk'标记了第二个输入B的axes。'ij'、'jk'的字符长度都是2,对应着A和B为2D数组,'ik'的长度也为2,因此输出也是2D数组。

给定输入:

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])
B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])
np.einsum('ij,jk->ik', A, B)可以看作是:

  • 在输入数组的标记之间,重复字母表示沿这些轴的值将相乘,这些乘积构成输出数组的值。比如图中沿着j轴做乘积。
  • 从输出标记中省略的字母表示沿该轴的值将被求和。比如图中的输出没有包含j轴,因此沿着j轴求和得到了输出数组中的每一项。
    • 如果输出的标记是'ijk',那么会得到一个 3x3x3 的矩阵。
      • 输出标记是'ik'的时候,并不会创建中间的 3x3x3 的矩阵,而是直接将总和累加到2D数组中。

    • 如果输出的标记是空,那么输出整个矩阵的和。
  • 我们可以按任意顺序返回不求和的轴。

我们将不指定'->'和输出标记称为 explicit mode。 如果不指定'->'和输出标记,numpy会将输入标记中只出现一次的标记按照字母表顺序,作为输出标记(也就是 implicit mode,后面会详细介绍)。

'ij,jk->ik' 等价于 'ij,jk'

在explicit mode中,我们可以指定输出标记的顺序,比如:'ij,jk->ki'表示对矩阵乘做转置。

Einsum 中的常用

对应的 einsum调用方式:

  • 向量操作:A、B均为向量

  • 向量操作:A、B均为2D矩阵

注意

  • einsum求和时不提升数据类型,如果使用的数据类型范围有限,可能会得到意外的错误:
a = np.ones(300, dtype=np.int8)
print(np.sum(a)) # correct result
print(np.einsum('i->', a)) # produces incorrect result
300
44
  • einsum 在implicit mode可能不会按预期的顺序排列轴
M = np.arange(24).reshape(2,3,4)
print(np.einsum('kij', M).shape) # 不是预期
print(np.einsum('ijk->kij', M).shape) #符合预期
(3, 4, 2)
(4, 2, 3)
np.einsum('kij', M) 实际上等价于 np.einsum('kij->ijk', M),因为 implicit mode 下,einsum会认为根据输入标记,按照字母表顺序排序,作为输出标记。
  • 最后,einsum 也不总是numpy中的最快的选择。
dot和inner函数之类的功能通常会链接到BLAS库方法,性能可能胜过einsum。 还有tensordot函数。 在多个输入数组上进行操作时,einsum似乎很慢。

看到这里就基本满足常用的要求啦,如果想深入了解大把细节,可以越过华丽丽的分割线,勇往直前!

------------------------------ 华丽丽的分割线 ------------------------------

numpy 中的 einsum

import numpy as np
# 给定两个向量,下面要用到
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
# 给定两个矩阵,下面要用到
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

接口定义:

numpy.einsum(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', optimize=False)

参数:

xxx表示可选参数中的默认值)

  • subscripts : str,指定求和的操作,可以是多个,用逗号分开。除非包含显式指示符“->”以及精确输出形式的下标标签,否则将执行隐式(经典的爱因斯坦求和)计算。
  • operands : list of array_like,输入。
  • out : ndarray, optional,指定输出。
  • dtype : {data-type, None}, optional,可以指定运算的数据类型,可能需要用户提供类型转换接口。默认为None。
  • order : {‘C’, ‘F’, ‘A’, ‘K’}, optional,控制输出的内存布局。'C'=contiguous;'F'=Fortran contiguous;'A'表示输入为'F'时输出为'F',否则输出为'C';'K'表示输出的layout应该尽可能和输入一致。
  • casting : {‘no’, ‘equiv’, ‘safe’, ‘same_kind’, ‘unsafe’}, optional,指定可能发生的数据类型转换,不推荐使用'unsafe'。
    • ‘no’ 表示不做数据类型转换。
    • ‘equiv’ 表示仅允许字节顺序更改。
    • ‘safe’ 表示只允许保留值的强制类型转换。
    • ‘same_kind’ 表示仅允许安全类型转换或同一类型(例如float64到float32)内的类型转换。
    • ‘unsafe’ 表示可以进行任何数据转换。
  • optimize : {False, True, ‘greedy’, ‘optimal’}, optional,控制是否进行中间优化。默认False,不做优化;设为True则使用greedy算法。还接受np.einsum_path函数的提供的列表。

Returns:

  • output : ndarray

在implicit模式下,选择的下标很重要,因为输出按字母顺序重新排序。例如,在二维矩阵中:

np.einsum('ij,jh', A, B) # 返回的是矩阵乘的转置,因为'h'本来应该是在'i'的后面,但是这里反序了
array([[19, 43],
       [22, 50]])

相比,在explicit模式下

np.einsum('ij,jh->ih', A, B) # 指定了输出下标标签的顺序,因此效果等价于矩阵乘法 np.matmul(A,B)
                            # 相比之下,implicit 模式的 np.einsum('ij,jh', A, B) 效果等价于矩阵乘的转置
array([[19, 43],
       [22, 50]])

einsum 默认不会 broadcast,需要指定省略号(...)来启用。默认的NumPy样式广播是通过在每项的左侧添加省略号。

另一种使用 enisum 的方式是:einsum(op0, sublist0, op1, sublist1, ..., [sublistout])。如果没有指定输出格式,将以 implicit 模式计算,否则将在 explicit 模式执行。

np.einsum(A, [0,0])
5

optimize参数优化 einsum 表达式的收缩顺序,对于具有三个或更多操作数的收缩,这可以大大增加计算效率,但需要在计算过程中增加内存占用量。

来一些对比:

  1. 求矩阵的迹:
print(np.einsum('ii', A)) # implicit mode
print(np.einsum(A, [0,0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.trace(A))
5
5
5
  1. 矩阵的对角元素
print(np.einsum('ii->i', A)) # explicit mode
print(np.einsum(A, [0,0], [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.diag(A))
[1 4]
[1 4]
[1 4]
  1. 对指定维度求和
print(np.einsum('ij->i', A)) # explicit mode
print(np.einsum(A, [0,1], [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.sum(A, axis=1))
[3 7]
[3 7]
[3 7]

对于高维的数组,可以使用省略号来对指定轴求和

print(np.einsum('...j->...', A)) # explicit mode
print(np.einsum(A, [Ellipsis,1], [Ellipsis])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
[3 7]
[3 7]
  1. 计算矩阵转置,或根据指定的 axes 调整矩阵
print(np.einsum('ji', A)) # implicit mode
print(np.einsum('ij->ji', A))# explicit mode
print(np.einsum(A, [1,0])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.einsum(A, [1,0], [0,1])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.transpose(A, axes=[1,0]))
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
[[1 3]
 [2 4]]
  1. 向量内积
print(np.einsum('i,i', a, a)) # implicit mode
print(np.einsum(a, [0], b, [0]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.inner(a, a))
14
32
14
  1. 矩阵向量乘积
k = np.array([1,2])
print(np.einsum('ij,j', A, k)) # implicit mode
print(np.einsum(A, [0,1], k, [1]))# einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 implicit mode
print(np.einsum('ij,j->i', A, k)) # explicit mode
print(np.einsum('...j,j->...', A, k)) # explicit mode with broadcast
print(np.dot(A, k))
[ 5 11]
[ 5 11]
[ 5 11]
[ 5 11]
[ 5 11]
  1. Broadcasting
print(np.einsum('...,...', 3, A)) # implicit mode
print(np.einsum(',ij', 3, A))# explicit mode
print(np.einsum(3, [Ellipsis], A, [Ellipsis]))
print(np.multiply(3, A))
[[ 3  6]
 [ 9 12]]
[[ 3  6]
 [ 9 12]]
[[ 3  6]
 [ 9 12]]
[[ 3  6]
 [ 9 12]]
  1. Tensor Contraction
m = np.arange(60.).reshape(3,4,5)
n = np.arange(24.).reshape(4,3,2)
print(np.einsum('ijk,jil->kl', m, n)) # explicit mode
print(np.einsum(a, [0,1,2], b, [1,0,3], [2,3])) # einsum(op0, sublist0, op1, sublist1, ..., [sublistout])方式 explicit mode
print(np.tensordot(a,b, axes=([1,0],[0,1]))) #
[[4400. 4730.]
 [4532. 4874.]
 [4664. 5018.]
 [4796. 5162.]
 [4928. 5306.]]
[[4400. 4730.]
 [4532. 4874.]
 [4664. 5018.]
 [4796. 5162.]
 [4928. 5306.]]
[[4400. 4730.]
 [4532. 4874.]
 [4664. 5018.]
 [4796. 5162.]
 [4928. 5306.]]
  1. 链式数组操作: For more complicated contractions, speed ups might be achieved by repeatedly computing a ‘greedy’ path or pre-computing the ‘optimal’ path and repeatedly applying it, using an einsum_path insertion. Performance improvements can be particularly significant with larger arrays. 待补充
     

TODO: Tensorflow 中的 einsum; Pytorch 中的 einsum

参考:

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值