einsum方法详解(爱因斯坦求和)

本文详细介绍了Python科学计算库NumPy和PyTorch中的einsum函数,它是基于爱因斯坦求和记法的一种高效矩阵运算工具。einsum能实现各种复杂的矩阵计算,如点积、转置、求和等。通过实例展示了einsum的使用方法,包括计算矩阵元素之和、矩阵的迹、对角线元素、向量相乘、矩阵点积、矩阵对应元素相乘及双线性运算等,帮助读者掌握这一强大的功能。
摘要由CSDN通过智能技术生成

einsum方法详解(爱因斯坦求和)

einsum是pytorch、numpy中一个十分优雅的方法,如果利用得当,可完全代替所有其他的矩阵计算方法,不过这需要一定的学习成本。本文旨在详细解读einsum方法的原理,并给出一些基本示例。

一、爱因斯坦求和

爱因斯坦求和是一种对求和公式简洁高效的记法,其原则是当变量下标重复出现时,即可省略繁琐的求和符号。

比如求和公式:
∑ i = 1 n a i b i = a 1 b 1 + a 2 b 2 + . . . + a n b n \begin{aligned} \sum_{i=1}^na_ib_i=a_1b_1+a_2b_2+...+a_nb_n \end{aligned} i=1naibi=a1b1+a2b2+...+anbn
其中变量 a a a与变量 b b b的下标是相同的(即重复出现),则可将其记为:
a i b i = ∑ i = 1 n a i b i \begin{aligned} a_ib_i=\sum_{i=1}^na_ib_i \end{aligned} aibi=i=1naibi

二、einsum方法原理

einsum方法正是利用了爱因斯坦求和简介高效的表示方法,从而可以驾驭任何复杂的矩阵计算操作。基本的框架如下:

C = einsum('ij,jk->ik', A, B)

上述操作表示矩阵A与矩阵B的点积。输入的参数分为两部分,前面表示计算操作的字符串,后面是以逗号隔开的操作对象(数量需与前面对应)。其中在计算操作表示中,"->“左边是以逗号隔开的下标索引,重复出现的索引即是需要爱因斯坦求和的;”->"右边的是最后输出的结果形式。
以上式为例,其计算公式为: C i k = ∑ j A i j B j k C_{ik} = \sum_jA_{ij}B_{jk} Cik=jAijBjk,其等价于矩阵A与B的点积。
这里有几条原则需要注意,之后也会和结合示例进行详解:

  1. "->"左边的是对应维度,以逗号隔开
  2. "->"右边的是最终output的形式
  3. 如果符号"->"被省略则代表输出为整体求和
  4. "…"表示省略之前或之后的所有维度
  5. einsum中涉及到的计算操作有很多,包括但不限于点积、对应元素相乘、求和、转置等

三、一些示例

einsum方法在numpy和pytorch中均有内置,这里以pytorch为例,首先定义一些需要用到的变量:

import torch
from torch import einsum
a = torch.ones(3,4)
b = torch.ones(4,5)
c = torch.ones(6,7,8)
d = torch.ones(3,4)
x, y = torch.randn(5), torch.randn(5)
  1. 计算矩阵所有元素之和
einsum('i,j', a)   # 等价于einsum('i,j->', a)
einsum('i,j,k', c)
  1. 计算矩阵的迹
einsum('ii', a)
  1. 获取矩阵对角线元素组成的向量
einsum('ii->i', a)
  1. 向量相乘得到矩阵
einsum('i,j->ij', x, y)
  1. 矩阵点积
einsum('ij,jk->ik', a, b)
  1. 矩阵对应元素相乘
einsum('ij,ij->ij', a, d)
  1. 矩阵的转置
einsum('ijk->ikj', c)
einsum('...jk->...kj', c)  # 两种形式等价
  1. 双线性运算
A = torch.randn(3,5,4)
l = torch.randn(2,5)
r = torch.randn(2,4)
torch.einsum('bn,anm,bm->ba', l, A, r)
  1. 最后来一个复杂的
a = torch.randn(3,4,5)
b = torch.randn(6,5)
c = torch.randn(6,3)
target = einsum('fti,di,df->dt', a,b,c)

# 把上面的求和过程写成循环的形式方便理解
# 对f和i进行求和
dt = torch.zeros(6,4)
for d in range(6):
	for t in range(4):
		tmp = 0
		for f in range(3):
			for i in range(5):
				tmp += a[f,t,i].item() * b[d,i].item() * c[d,f].item()
		dt[d,t] = tmp

print(target)
print('\n', dt)

# 结果:
tensor([[ 0.5712, -0.9473,  0.1972,  0.6474],
        [ 0.6273, -1.3820,  2.0779, -0.5729],
        [ 4.0083, -2.9979,  0.9493,  0.8467],
        [15.9985, -7.1897,  1.1957,  2.7394],
        [ 0.4298, -0.6509,  0.5128,  1.7027],
        [-4.6843, -4.7948,  1.1389, -6.9389]])
        
tensor([[ 0.5712, -0.9473,  0.1972,  0.6474],
        [ 0.6273, -1.3820,  2.0779, -0.5729],
        [ 4.0083, -2.9979,  0.9493,  0.8467],
        [15.9985, -7.1897,  1.1957,  2.7394],
        [ 0.4298, -0.6509,  0.5128,  1.7027],
        [-4.6843, -4.7948,  1.1389, -6.9389]])
  • 30
    点赞
  • 89
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
引用\[2\]和\[3\]提供了集成Mybatis-Plus的方式,而Mybatis-Plus是一个增强版的Mybatis框架,提供了更多的便利功能和增强特性。Mybatis-Plus中的方法主要包括CRUD操作和条件构造器。 在Mybatis-Plus中,CRUD操作的方法包括insert、delete、update和select。这些方法可以通过继承BaseMapper接口来使用,或者使用Mybatis-Plus提供的通用Mapper接口。通用Mapper接口提供了一系列的方法,如selectById、selectList、insert、updateById等,可以直接使用这些方法进行数据库操作。 除了CRUD操作,Mybatis-Plus还提供了条件构造器来方便地构建复杂的查询条件。条件构造器包括QueryWrapper、UpdateWrapper和LambdaQueryWrapper等,可以通过链式调用的方式来构建查询条件,如eq、like、in等。这些条件构造器可以与CRUD操作的方法一起使用,以实现更加灵活和精确的查询。 总结起来,Mybatis-Plus提供了一系列的方法来简化和增强Mybatis的使用,包括CRUD操作和条件构造器。通过集成Mybatis-Plus,可以更加方便地进行数据库操作和查询。 #### 引用[.reference_title] - *1* *2* *3* [Mybatis-Plus详解](https://blog.csdn.net/bier_zm/article/details/125808590)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值