pytorch中einsum用法总结

爱因斯坦求和约定

爱因斯坦求和约定(einsum)提供了一套既简洁又优雅的规则,可实现包括但不限于:向量内积,向量外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练运用 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。

自由索引(Free indices)和求和索引(Summation indices):

  • 自由索引:出现在箭头右边的索引
  • 求和索引:只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出

三条基本规则

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和;
  • 规则三,equation 箭头右边的索引顺序可以是任意的

使用方法

提取元素

提取矩阵对角线元素
import torch
a = torch.arange(9).reshape(3, 3)
x = torch.einsum('ii->i', [a])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
        
print(x)
tensor([0, 4, 8])

转置

矩阵转置
import torch
a = torch.arange(6).reshape(2, 3)
x = torch.einsum('ij->ji', [a])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
        
print(x)
tensor([[0, 3],
        [1, 4],
        [2, 5]])
permute 高维张量转置
import torch
a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
x = torch.einsum('...ij->...ji', [a])
print(a.shape)
torch.Size([2, 3, 5, 7, 9])

print(x.shape)
torch.Size([2, 3, 5, 9, 7])

求和

矩阵求和
import torch
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij->', [a])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
     
print(x)
tensor(15)
矩阵按行求和
import torch
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij->i', [a])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
     
print(x)
tensor([ 3, 12])
矩阵按列求和
import torch
a = torch.arange(6).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij->j', [a])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
     
print(x)
tensor([3, 5, 7])

乘积

矩阵向量乘法
import torch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
# i = 2, k = 3
x = torch.einsum('ik,k->i', [a, b])
# 等价形式,可以省略箭头和输出
x2 = torch.einsum('ik,k', [a, b])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
        
print(b)
tensor([0, 1, 2])

print(x)
tensor([ 5, 14])

print(x2)
tensor([ 5, 14])
矩阵乘法
import torch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
# i = 2, k = 3, j = 5
x = torch.einsum('ik,kj->ij', [a, b])
# 等价形式,可以省略箭头和输出
x2 = torch.einsum('ik,kj', [a, b])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
        
print(b)
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])

print(x)
tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])

print(x2)
tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])
向量内积
import torch
a = torch.arange(3)
b = torch.arange(3, 6) # [3, 4, 5]
# i = 3
x = torch.einsum('i,i->', [a, b])
# 等价形式,可以省略箭头和输出
x2 = torch.einsum('i,i', [a, b])
print(a)
tensor([0, 1, 2])

print(b)
tensor([3, 4, 5])

print(x)
tensor(14)

print(x2)
tensor(14)
矩阵元素对应相乘并求和
import torch
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
# i = 2, j = 3
x = torch.einsum('ij,ij->', [a, b])
# 等价形式,可以省略箭头和输出
x2 = torch.einsum('ij,ij', [a, b])
print(a)
tensor([[0, 1, 2],
        [3, 4, 5]])
       

print(b)
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
       

print(x)
tensor(145)

print(x2)
tensor(145)
向量外积
import torch
a = torch.arange(3)
b = torch.arange(3,7)  # [3, 4, 5, 6]
# i = 3, j = 4
x = torch.einsum('i,j->ij', [a, b])
# 等价形式,可以省略箭头和输出
x2 = torch.einsum('i,j', [a, b])
print(a)
tensor([0, 1, 2])

print(b)
tensor([3, 4, 5, 6])

print(x)
tensor([[ 0,  0,  0,  0],
        [ 3,  4,  5,  6],
        [ 6,  8, 10, 12]])

print(x2)
tensor([[ 0,  0,  0,  0],
        [ 3,  4,  5,  6],
        [ 6,  8, 10, 12]])
batch 矩阵乘法
import torch
a = torch.randn(2,3,5)
b = torch.randn(2,5,4)
# i = 2, j = 3, k = 5, l = 4
x = torch.einsum('ijk,ikl->ijl', [a, b])
print(a.shape)
torch.Size([2, 3, 5])

print(b.shape)
torch.Size([2, 5, 4])

print(x.shape)
torch.Size([2, 3, 4])

变换

张量收缩(tensor contraction)
import torch
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
# p = 2, q = 3, r = 5, s = 7
# t = 11, u = 13, v = 17, r = 5
x = torch.einsum('pqrs,tuqvr->pstuv', [a, b])
print(a.shape)
torch.Size([2, 3, 5, 7])

print(b.shape)
torch.Size([11, 13, 3, 17, 5])

print(x.shape)
torch.Size([2, 7, 11, 13, 17])
二次变换(bilinear transformation)
import torch
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
# i = 2, k = 3, j = 5, l = 7
x = torch.einsum('ik,jkl,il->ij', [a, b, c])
print(a.shape)
torch.Size([2, 3])

print(b.shape)
torch.Size([5, 3, 7])

print(c.shape)
torch.Size([2, 7])

print(x.shape)
torch.Size([2, 5])

其他操作

  • 向量操作:A、B均为向量
    在这里插入图片描述
  • 向量操作:A、B均为2D矩阵
    在这里插入图片描述
    注意:
  • einsum求和时不提升数据类型,如果使用的数据类型范围有限,可能会得到意外的错误:
a = np.ones(300, dtype=np.int8)
print(np.sum(a)) # correct result
print(np.einsum('i->', a)) # produces incorrect result
300
44
  • einsum 在implicit mode可能不会按预期的顺序排列轴
M = np.arange(24).reshape(2,3,4)
print(np.einsum('kij', M).shape) # 不是预期
print(np.einsum('ijk->kij', M).shape) #符合预期
(3, 4, 2)
(4, 2, 3)

np.einsum(‘kij’, M) 实际上等价于 np.einsum(‘kij->ijk’, M),因为 implicit mode 下,einsum会认为根据输入标记,默认按照字母表顺序排序,作为输出标记。

参考:
https://zhuanlan.zhihu.com/p/101157166
https://zhuanlan.zhihu.com/p/361209187

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值