torch.einsum()
是 PyTorch 中一个功能强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention)。这个函数可以用来进行多维数组(张量)间的各种操作,如矩阵乘法、内积、外积、转置和更多复杂的张量操作。
函数签名
torch.einsum(equation, *operands)
参数
- equation:一个字符串,描述了操作的模式和求和的方式。格式如下:
- 逗号分隔的子字符串表示每个输入张量。
- 箭头
->
右侧的子字符串表示输出张量的模式(如果省略,默认会根据输入张量和求和规则自动推断)。 - 子字符串中的每个字母表示对应维度标签。相同字母的维度会被求和(如果没有箭头部分)。
- operands:一个或多个输入张量,这些张量将根据
equation
指定的模式进行操作。
返回值
- 返回一个张量,是根据
equation
指定的模式进行计算的结果。
用法示例
示例 1:矩阵乘法
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
C = torch.einsum('ij,jk->ik', A, B)
print(C)
"""
# Output
tensor([[19, 22],
[43, 50]])
"""
解释:
'ij'
表示矩阵A
,'jk'
表示矩阵B
。- 矩阵乘法的结果是将
A
的列和B
的行配对求和,结果是一个新的矩阵C
,其模式为'ik'
。
示例 2:内积
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.einsum('i,i->', a, b)
print(c)
"""
# Output
tensor(32)
"""
解释:
'i'
表示向量a
和b
的维度标签。- 内积操作对两个向量相应位置的元素进行乘积并求和,结果是一个标量。
示例 3:外积
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.einsum('i,j->ij', a, b)
print(c)
"""
# Output
tensor([[ 4, 5, 6],
[ 8, 10, 12],
[12, 15, 18]])
"""
解释:
'i'
表示向量a
的维度标签,'j'
表示向量b
的维度标签。- 外积操作生成一个矩阵,其中元素是
a
和b
相应位置元素的乘积。
其他说明
torch.einsum()
根据提供的爱因斯坦求和约定字符串,对输入张量进行操作。这个字符串定义了输入张量的维度如何与输出张量的维度相关联。通过这种方式,einsum
可以实现从简单的矩阵乘法到复杂的张量运算的多种操作。
优势
- 灵活性:可以处理各种张量操作,不需要编写复杂的循环。
- 简洁性:使用一个字符串定义操作模式,代码更简洁。
- 性能:高效实现,因为底层操作由优化的库支持。
通过 torch.einsum()
,可以方便地实现复杂的张量操作,尤其适用于需要多维度操作和求和的场景。
具体例子
x = torch.einsum('vn,bfnt->bfvt',(A,x))
这段代码使用了 torch.einsum()
函数,通过爱因斯坦求和约定实现了张量操作。
'vn,bfnt->bfvt'
可以分解为以下几部分:
'vn'
:表示张量A
的维度标签。'bfnt'
:表示张量x
的维度标签。'->bfvt'
:表示输出张量的维度标签。
张量解释
这里:
V
、N
、B
、F
、T
是正整数,表示张量的不同维度。vn
表示A
的第一个维度是v
,第二个维度是n
。bfnt
表示x
的四个维度分别是b
、f
、n
和t
。
操作过程
-
求和与乘积:
- 根据
vn
和bfnt
,torch.einsum('vn,bfnt->bfvt', (A, x))
会计算A
和x
的乘积并对n
维度求和。 A
的n
维度和x
的n
维度会匹配,意味着我们对n
维度上的元素进行乘积并求和。
- 根据
-
输出维度:
- 乘积和求和操作后,
n
维度会被消去,A
的v
维度和x
的b
、f
、t
维度会保留下来。 - 最终输出张量的形状是
(B, F, V, T)
,即维度标签为bfvt
。
- 乘积和求和操作后,
示例
假设具体形状如下:
A
的形状为(3, 4)
,表示V = 3
,N = 4
。x
的形状为(2, 5, 4, 6)
,表示B = 2
,F = 5
,N = 4
,T = 6
。
代码如下:
import torch
# 定义张量 A 和 x
A = torch.randn(3, 4) # 形状 (V, N)
x = torch.randn(2, 5, 4, 6) # 形状 (B, F, N, T)
# 使用 torch.einsum 进行张量操作
result = torch.einsum('vn,bfnt->bfvt', (A, x))
# 输出结果形状
print(result.shape) # 输出 (2, 5, 3, 6)
解释:
A
和x
的n
维度是 4,进行乘积并对该维度求和。- 结果是一个形状为
(2, 5, 3, 6)
的张量。
最后
通过 torch.einsum('vn,bfnt->bfvt', (A, x))
,可对张量 A
和 x
进行了一个指定维度的乘积和求和操作,最终生成了一个新的张量,其形状由输入张量的其他维度和公式中的约定确定。