einsum(equation, *tensors)函数用法【附代码】

einsum(equation, *tensors)函数对输入的一个或多个张量(*tensors)作定义为equation的运算 

通常来说,equation由输入的张量(*tensors)的shape来定义的,类似于einops模块中的rearrange, reduce, repeat但规则不同: 
1. 一般来说,eqution的形式是'i j -> j i'(矩阵转置例子,这里的i, j对应的是张量shape的对应维度) # 如shape:(2, 3);可知i=2, j=3. 
2. 当箭头'->'右侧为空,默认为箭头左侧只出现一次的值,如'ij,jk -> '等价于'ij, jk -> ik'(矩阵乘法) 
3. 对于箭头左侧重复出现的值(指向对应张量shape的某个维度),代表沿着该维度作乘法操作,如'ij, jk -> ik'(沿着'j'维度作乘法,矩阵乘法) 
4. 在规则2的基础上,我们可以任意调整箭头右侧维度的顺序进而调整输出,如'ij, jk -> ki'(就是先做矩阵乘法,ik再转置为ki)
5. 若对多个张量进行操作,需在箭头左侧用逗号隔开表示指向不同张量的shape 

​​​​​​​import einsum
import torch


x = [
    [
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12]
    ],
    [
        [13, 14, 15, 16],
        [17, 18, 19, 20],
        [21, 22, 23, 24]
    ]
]
# 将x转化为张量,数据类型为浮点数
x = torch.tensor(x).float()
print(x.shape)  # torch.Size([2, 3, 4])

# einsum()函数对张量进行降维
x1_1 = torch.einsum('b h w -> h w', x)  
# 这里等效于reduce('b h w -> h w', 'max')函数,默认采用最大池化
print(x1_1)
print(x1_1.shape)
"""
tensor([[14., 16., 18., 20.],
        [22., 24., 26., 28.],
        [30., 32., 34., 36.]])
torch.Size([3, 4])
"""

# einsum()函数交换张量shape的特定维度(支持一次性交换多个维度)
x1_2 = torch.einsum('i j k -> k i j', x)
print(x1_2)
print(x1_2.shape)
"""
tensor([[[ 1.,  5.,  9.],
         [13., 17., 21.]],
        [[ 2.,  6., 10.],
         [14., 18., 22.]],
        [[ 3.,  7., 11.],
         [15., 19., 23.]],
        [[ 4.,  8., 12.],
         [16., 20., 24.]]])
torch.Size([4, 2, 3])
"""

# torch.einsum()不支持对维度通过数值乘除放缩,
# 因为equation中' -> '箭头两侧元素必须都是在{a/A-z/Z}即所有大小写字母组成的集合中
x1_3 = torch.einsum('...h w -> ...w h', x)  # 交换张量x.shape的后两个维度
print(x1_3)
print(x1_3.shape)
"""
tensor([[[ 1.,  5.,  9.],
         [ 2.,  6., 10.],
         [ 3.,  7., 11.],
         [ 4.,  8., 12.]],
        [[13., 17., 21.],
         [14., 18., 22.],
         [15., 19., 23.],
         [16., 20., 24.]]])
torch.Size([2, 4, 3])
"""

# torch.einsum()函数还支持矩阵乘法(本质是broadcasting广播规则),矩阵trace等运算
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.shape)  # torch.Size([2, 3])
b = torch.tensor([[7, 8], [9, 10], [11, 12]])
print(b.shape)  # torch.Size([3, 2])

x2_1 = torch.einsum('ij, jk -> ik', a, b)  # 矩阵乘法
print(x2_1)
print(x2_1.shape)
"""
tensor([[ 58,  64],
        [139, 154]])
torch.Size([2, 2])
"""
c = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
x2_2 = torch.einsum('ii -> ', c)  # 矩阵trace的运算
print(x2_2)  # tensor(15);就是一个数
print(x2_2.shape)  # torch.Size([]),有点odd不太懂
print(x2_2.shape is None)  # False

# torch.einsum()实现两个向量直接的点积
d = torch.tensor([1, 2, 3])
e = torch.tensor([4, 5, 6])

x3_1 = torch.einsum('i, i -> ', d, e)  # 向量点积
print(x3_1)  # tensor(32)
print(x3_1.shape)  # torch.Size([])

                
  • 8
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值