参考:https://www.cnblogs.com/mengnan/p/10319701.html
爱因斯坦简记法,能简洁表示各种矩阵向量的操作,例如矩阵转置、乘法、求和等等,pytorch中调用API为torch.einsum,第一个参数是字符串,表示对张量的操作,定义不同字符串表示不同操作
例子
矩阵转置
字符串为ij->ji
,表示将张量中的元素ij(第i行第j列)变成元素ji
a = torch.tensor([
[1, 2],
[2, 3],
[4, 5],
[2, 6]
])
b = torch.einsum('ij->ji', a)
print(b)
>>> tensor([[1, 2, 4, 2],
[2, 3, 5, 6]])
矩阵点乘(哈达玛积)
将矩阵对应位置的元素相乘,字符串ij,ij->ij
表示将两个矩阵的ij元素相乘得到元素ij,貌似只能是相乘不能是相加?
a = torch.tensor([
[1, 2],
[2, 3],
[4, 5],
[2, 6]
])
b = torch.tensor([
[2, 6],
[2, 4],
[3, 1],
[2, 2]
])
c = torch.einsum('ij,ij->ij', a, b)
print(c)
>>> tensor([[ 2, 12],
[ 4, 12],
[12, 5],
[ 4, 12]])
矩阵乘法
将矩阵
a
∈
R
M
×
N
a\in R^{M\times N}
a∈RM×N和矩阵
b
∈
R
N
×
S
b\in R^{N\times S}
b∈RN×S相乘得到矩阵
c
∈
R
M
×
S
c\in R^{M\times S}
c∈RM×S,字符串ik,kj->ij
表示将矩阵a的ik元素和矩阵b的kj元素相乘,k是遍历变量,遍历所有k,累加ik*kj累加得到ij,即矩阵乘法
a = torch.tensor([
[1, 2, 4],
[2, 3, 5],
])
b = torch.tensor([
[2, 6],
[2, 4],
[3, 1],
])
c = torch.einsum('ik,kj->ij', a, b)
print(c)
>>> tensor([[18, 18],
[25, 29]])