Before Writing
爱因斯坦和是关于矩阵运算的通用工具,面对矩阵乘积求和运算十分便利,原理是通过矩阵下标索引进行操作的(大多时候优于我们写的乘法性能)。
这里着重介绍torch.einsum
的使用和技巧,便于使用到具体的科研和工程中。
关于更多einsum的内容可以参考:Torch,Numpy,StackOverFlow
TORCH.EINSUM
-
作用:对输入按着特定维度进行乘积求和
- 假设我们有矩阵
A
和B
- 如果对
AB
进行矩阵乘法 得到一个新矩阵C
- 然后对
C
矩阵的一个维度进行求和 - 最后再交换
C
矩阵中的维度
- 如果对
- 如果想完成上述操作,我们可以通过
C = A @ (B.T)
,然后对其进行sum
和transpose
,但是如此操作不仅语句冗杂,也不是很memory-efficiently,所以einsum出现可以完美的解决上面的问题。
A = array([[1, 1, 1], [2, 2, 2], [5, 5, 5]]) B = array([[0, 1, 0], [1, 1, 0], [1, 1, 1]]) # 使用einsum C = torch.einsum('ij, jk -> ik', A, B)
- 假设我们有矩阵
-
使用说明:
- 使用einsum表达式时应当注意
torch.einsum('ij,jk->ik', A, B)
中下标字母应当在[a-z A-Z]当中,使用字符的顺序和矩阵的维度一一对应,比如i
代表第一个矩阵的行数,j
对应第一个矩阵的列数;这里相当于矩阵A
的行与矩阵B
的列进行内积,最后得到维度为ik
的矩阵。 - 使用表达式时应注意
,
用来区分一个输入还是两个输入,如果使用就说明对两个输入进行操作。 - 使用是矩阵的维度必须是broadcastable,即维度要对应上或者是be-like。
- 例外情况是,如果对同一输入重复下标
torch.einsum('ii',A)
,在这种情况下,下标i
必须和A
的维度对应上,其实现的操作是对角线的和(trace)。 - 对于省略号(Ellipsis)的使用:
- 省略号能够代替字母字符作为下标,通常对一个矩阵中最多使用一次
- 比如矩阵
A
有5个维度(2, 3, 4, 3, 5)
,如果用ab...c
索引A
则a
代表2,b
代表3,...
代表4,3两个维度,c
代表5这个维度。 - 同样的如果用
..ij
代表A
,...
代表2,3,4这三个维度,ij
代表3,5这最后两个维度。
- 比如矩阵
- 省略号能够代替字母字符作为下标,通常对一个矩阵中最多使用一次
- 对于空格来讲,我们随便的加入到字符下标,箭头,逗号都不会对表达式产生影响,但如果加到省略号里面就会产生影响,所以总结起来就是尽量别再einsum中使用空格。
- 使用einsum表达式时应当注意
使用示例
输入参数:
- 字符下标(str)
- 操作矩阵(List[Tensor])
输出参数:
- tensor
迹trace
求对角元素的和
>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)
对角元素diagonal
-
取出对角元素
>>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545])
-
如果面对多维矩阵,会先进行squeeze再取对角元素
>>> X = torch.arange(36).reshape(1, 6, 1, 6) >>> torch.einsum('ijij->ij', X) tensor([[ 0, 7, 14, 21, 28, 35]])
矩阵乘积product
-
对矩阵进行乘积
>>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]])
-
如果我们在训练时想保留batch_size维度可以这么操作:
>>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]])
-
也可已通过list作为下标:
torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
中[..., 0, 1]
表示将矩阵As
最后的两个维度记为0,1。[..., 1, 2]
表示将矩阵Bs
最后的两个维度记为1,2。- 最后得到维度为
[..., 0, 2]
的矩阵。
>>> As = torch.ones(3, 2, 5) >>> Bs = torch.ones(3, 5, 4) >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[5., 5., 5., 5.], [5., 5., 5., 5.]], [[5., 5., 5., 5.], [5., 5., 5., 5.]], [[5., 5., 5., 5.], [5., 5., 5., 5.]]])
-
如果想对多个矩阵做连续乘法
>>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])
维度变换permute
只需要对->
两端维度改变位置即可:
>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])
向量进行外积
torch.einsum('i,j->ij', x[0], x[0])
总结
Einsum是一个随心所欲的方法,使得矩阵乘法运算变得更易于描述,因此我们在做复杂矩阵乘积时应当首先考虑使用。