我们知道einsum的第一个参数是字符串,它表示后面参数的运算规则。这个运算规则具体是什么样子的呢?如下所示:
规则1:不同输入如果有相同的字母,表示这个字母代表的维度会被相乘,相乘的结果会被输出。例如下面的例子,j出现了两次,那么两个输入会在j这位维度上相乘。
m=np.einsum("ij,jk->ik",a,b)
规则2:如果在输出里面,一个字母消失,则代表会在这个字母所在的维度上求和。还是上面这个例子。j在输出中消失了,那么j原先的乘积的结果会求和再输出。
m=np.einsum("ij,jk->ik",a,b)
规则3:没有进行求和的维度可以以任意顺序输出。这个容易理解,就是可以任意更换维度的顺序。
m=np.einsum("ijk->ikj",a)
我们来看看einsum具体的运算逻辑的例子。 还是用上面的例子,下面这个两个运算是等价的。i和k在输入中只出现了一次,写在循环的最外面;j在输入中重复了,要进行乘法运算,写在循环最里面;j在输出中消失了,所以乘法的结果需要加起来。
m=np.einsum("ij,jk->ik",a,b)
import numpy as np
a=np.random.rand(5,6)
b=np.random.rand(6,3)
m=np.zeros((5,3))
for i in range(5):
for k in range(3):
for j in range(6):
m[i,k]+=a[i,j]*b[j,k]
那如果einsum的计算如下又怎么理解呢?
m=np.einsum("i,k->ik",a,b)
其实和上面的例子类似,i和j在输出中没有重复,写在循环最外面;没有重复的字母所以没有额外的循环了;输出中没有消失的字母,所以没有求和操作存在。
import numpy as np
a=np.random.rand(5)
b=np.random.rand(6)
m=np.zeros((5,6))
for i in range(5):
for j in range(6):
m[i,j]=a[i]*b[j]
上面这个例子为什么最后一行有乘法,如果我们假象有一个重复的字母k存在,且这个字母所在维度的长度为1,那么上面代码经过修改,就和"ij,jk->ik"的代码完全一样了。
import numpy as np
a=np.random.rand(5,1)
b=np.random.rand(1,6)
m=np.zeros((5,6))
for i in range(5):
for k in range(6):
for j in range(1):
m[i,k]+=a[i,j]*b[j,k]
再来一个例子。这个怎么计算?
m=np.einsum("ij,kj->ik",a,b)
我们把等价的循环写出来,就比较清楚了。除了下标位置不同,代码结构和"ij,jk->ik"是一样的。
import numpy as np
a=np.random.rand(5,3)
b=np.random.rand(6,3)
m=np.zeros((5,6))
for i in range(5):
for k in range(6):
for j in range(3)
m[i,k]+=a[i,j]*b[k,j]