Einsum爱因斯坦和

Before Writing

爱因斯坦和是关于矩阵运算的通用工具,面对矩阵乘积求和运算十分便利,原理是通过矩阵下标索引进行操作的(大多时候优于我们写的乘法性能)。

这里着重介绍torch.einsum的使用和技巧,便于使用到具体的科研和工程中。

关于更多einsum的内容可以参考:TorchNumpyStackOverFlow


TORCH.EINSUM

  • 作用:对输入按着特定维度进行乘积求和

    • 假设我们有矩阵AB
      • 如果对AB进行矩阵乘法 得到一个新矩阵C
      • 然后对C矩阵的一个维度进行求和
      • 最后再交换C矩阵中的维度
    • 如果想完成上述操作,我们可以通过C = A @ (B.T),然后对其进行sumtranspose,但是如此操作不仅语句冗杂,也不是很memory-efficiently,所以einsum出现可以完美的解决上面的问题。
    A = array([[1, 1, 1],
               [2, 2, 2],
               [5, 5, 5]])
    
    B = array([[0, 1, 0],
               [1, 1, 0],
               [1, 1, 1]])
    
    # 使用einsum
    C = torch.einsum('ij, jk -> ik', A, B)
    

    在这里插入图片描述

  • 使用说明:

    • 使用einsum表达式时应当注意torch.einsum('ij,jk->ik', A, B)中下标字母应当在[a-z A-Z]当中,使用字符的顺序和矩阵的维度一一对应,比如i代表第一个矩阵的行数,j对应第一个矩阵的列数;这里相当于矩阵A的行与矩阵B的列进行内积,最后得到维度为ik的矩阵。
    • 使用表达式时应注意,用来区分一个输入还是两个输入,如果使用就说明对两个输入进行操作。
    • 使用是矩阵的维度必须是broadcastable,即维度要对应上或者是be-like。
    • 例外情况是,如果对同一输入重复下标torch.einsum('ii',A),在这种情况下,下标i必须和A的维度对应上,其实现的操作是对角线的和(trace)。
    • 对于省略号(Ellipsis)的使用:
      • 省略号能够代替字母字符作为下标,通常对一个矩阵中最多使用一次
        • 比如矩阵A有5个维度(2, 3, 4, 3, 5),如果用ab...c索引Aa代表2,b代表3,...代表4,3两个维度,c代表5这个维度。
        • 同样的如果用..ij代表A...代表2,3,4这三个维度,ij代表3,5这最后两个维度。
    • 对于空格来讲,我们随便的加入到字符下标,箭头,逗号都不会对表达式产生影响,但如果加到省略号里面就会产生影响,所以总结起来就是尽量别再einsum中使用空格。

使用示例

输入参数:

  • 字符下标(str)
  • 操作矩阵(List[Tensor])

输出参数:

  • tensor

迹trace

求对角元素的和

>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

对角元素diagonal

  • 取出对角元素

    >>> # diagonal
    >>> torch.einsum('ii->i', torch.randn(4, 4))
    tensor([-0.1034,  0.7952, -0.2433,  0.4545])
    
  • 如果面对多维矩阵,会先进行squeeze再取对角元素

    >>> X = torch.arange(36).reshape(1, 6, 1, 6)   
    >>> torch.einsum('ijij->ij', X)
    tensor([[ 0,  7, 14, 21, 28, 35]])
    

矩阵乘积product

  • 对矩阵进行乘积

    >>> # outer product
    >>> x = torch.randn(5)
    >>> y = torch.randn(4)
    >>> torch.einsum('i,j->ij', x, y)
    tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
            [-0.3744,  0.9381,  1.2685, -1.6070],
            [ 0.7208, -1.8058, -2.4419,  3.0936],
            [ 0.1713, -0.4291, -0.5802,  0.7350],
            [ 0.5704, -1.4290, -1.9323,  2.4480]])
    
  • 如果我们在训练时想保留batch_size维度可以这么操作:

    >>> # batch matrix multiplication
    >>> As = torch.randn(3, 2, 5)
    >>> Bs = torch.randn(3, 5, 4)
    >>> torch.einsum('bij,bjk->bik', As, Bs)
    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
            [-1.6706, -0.8097, -0.8025, -2.1183]],
    
            [[ 4.2239,  0.3107, -0.5756, -0.2354],
            [-1.4558, -0.3460,  1.5087, -0.8530]],
    
            [[ 2.8153,  1.8787, -4.3839, -1.2112],
            [ 0.3728, -2.1131,  0.0921,  0.8305]]])
    
  • 也可已通过list作为下标:

    • torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
      • [..., 0, 1]表示将矩阵As最后的两个维度记为0,1。
      • [..., 1, 2]表示将矩阵Bs最后的两个维度记为1,2。
      • 最后得到维度为 [..., 0, 2]的矩阵。
      >>> As = torch.ones(3, 2, 5)
      >>> Bs = torch.ones(3, 5, 4)
      >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
      tensor([[[5., 5., 5., 5.],
               [5., 5., 5., 5.]],
      
              [[5., 5., 5., 5.],
               [5., 5., 5., 5.]],
      
              [[5., 5., 5., 5.],
               [5., 5., 5., 5.]]])
      
  • 如果想对多个矩阵做连续乘法

    >>> # equivalent to torch.nn.functional.bilinear
    >>> A = torch.randn(3, 5, 4)
    >>> l = torch.randn(2, 5)
    >>> r = torch.randn(2, 4)
    >>> torch.einsum('bn,anm,bm->ba', l, A, r)
    tensor([[-0.3430, -5.2405,  0.4494],
            [ 0.3311,  5.5201, -3.0356]])
    

维度变换permute

只需要对->两端维度改变位置即可:

>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])


向量进行外积

torch.einsum('i,j->ij', x[0], x[0])


总结

Einsum是一个随心所欲的方法,使得矩阵乘法运算变得更易于描述,因此我们在做复杂矩阵乘积时应当首先考虑使用。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值