【PyTorch函数解析】einsum的用法示例

一、前言

einsum 是一个非常强大的函数,用于执行张量(Tensor)运算。它的名称来源于爱因斯坦求和约定(Einstein summation convention),在PyTorch中,einsum 可以方便地进行多维数组的操作和计算。

在Transfomer中,einsum用的非常多,比如使用 einsum 实现自注意力机制中注意力权重的获取,也就是Q和K的内积:

  • Q(Query):形状为 (batch_size, seq_len, d_k)

  • K(Key):形状为 (batch_size, seq_len, d_k)

import torch
import torch.nn.functional as F

Q = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)

# (batch_size, seq_len, seq_len)
attention_scores = torch.einsum('bqd,bkd->bqk', Q, K) / torch.sqrt(torch.tensor(64.0))
# (batch_size, seq_len, seq_len)   
attention_weights = F.softmax(attention_scores, dim=-1)  

二、常见用法示例

2.1 向量点积

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result)  # 输出 32

这里,'i,i->' 表示对向量 a 和 b 进行点积操作,其中 i 是索引表示,-> 之后为空表示求和。

2.2 矩阵乘法

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result)  # 输出 tensor([[19, 22], [43, 50]])

这里,'ij,jk->ik' 表示矩阵乘法,其中 i 和 k 是结果的维度,j 是求和维度。

2.3 批量矩阵乘法

A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)
result = torch.einsum('bij,bjk->bik', A, B)

这里,'bij,bjk->bik' 表示对批量的矩阵进行乘法运算。

解释:

bij,bjk分别是A和B的3个维度,用字符串的形式指代。

为什么最后得到的是bik呢?这个和线性代数的矩阵运算规则有关系。

矩阵乘法规则:

  • 给定矩阵 A 的形状为 (m,n)

  • 给定矩阵 B 的形状为 (n,p)

  • 矩阵乘法 A×B 的结果矩阵 C 的形状为 (m,p)

在矩阵乘法中,结果矩阵的每个元素 Cik 是通过 A 的第 i 行和 B 的第 k 列的对应元素相乘并求和得到的,即:

C_{ik}=\sum_{j=1}^nA_{ij}\cdot B_{jk}

计算过程:

1. 匹配批次维度 (b)

  • 对于每个批次,独立进行矩阵乘法运算。

2. 求和维度 (j):

  • j 是两个张量中共同的维度,根据线性代数中的矩阵乘法规则,需要对 j 维度进行求和。

3. 保留和产生的维度:

  • i 来自 A,表示保留 A 的第一个维度。

  • k 来自 B,表示保留 B 的第二个维度。

经过上述分析,einsum 的结果保留了 b(批次维度)、i(来自 A 的第一个维度)和 k(来自 B 的第二个维度)。因此,结果张量的形状为 (batch_size, seq_len_i, seq_len_k),也就是 bik。

同样,延伸到4维计算的话。

torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

首先,假设 queries 和 keys 的形状为:

  • queries: (batch_size, seq_len_q, num_heads, head_dim)

  • keys: (batch_size, seq_len_k, num_heads, head_dim)

用具体变量名表示:

  • n: batch_size,批次大小。

  • q: seq_len_q,查询序列的长度。

  • k: seq_len_k,键序列的长度。

  • h: num_heads,多头注意力中的头数。

  • d: head_dim,每个头的维度。

1. 匹配批次维度 (n) 和头部维度 (h):

  • 批次大小和头部数量在两个输入张量中都是相同的,保持不变。

2. 求和维度 (d):

  • d 表示每个头的维度。在 queries 和 keys 中,d 都是最后一个维度,对这个维度进行点积运算后求和。

3. 保留和产生的维度:

  • q 来自 queries,表示查询序列的长度。

  • k 来自 keys,表示键序列的长度。

所以最后是nhqk。

2.4 转置操作

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->ji', A)
print(result)  # 输出 tensor([[1, 4], [2, 5], [3, 6]])

这里,'ij->ji' 表示将矩阵进行转置操作。

  • 14
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中的`torch.einsum`函数是一个用于执行张量运算的强大工具。它可以根据指定的公式对输入张量进行操作,并生成输出张量。 引用\[1\]中提供了一些常见的用法示例。例如,可以使用`torch.einsum`计算矩阵的行和、列和以及某个维度的和。例如,可以使用`torch.einsum('ij->i', A)`计算矩阵A的行和,使用`torch.einsum('ij->j', A)`计算矩阵A的列和,使用`torch.einsum('ijklmn->n', D)`计算张量D在某个维度上的和。 引用\[2\]中提供了一些更复杂的用法示例。例如,可以使用`torch.einsum('ij,jk->ik', A, B)`计算矩阵A和B的内积,使用`torch.einsum('ij,ik->jk', A, C)`计算矩阵A和C的外积,使用`torch.einsum('ij,jk,lj->jk', A, B, C)`进行多维张量相乘。 引用\[3\]中提供了一个高阶张量运算的示例。在这个示例中,使用`np.einsum('ijk,jil->kl', a, b)`计算了两个3阶张量a和b的乘积,并生成了一个2阶张量o。这个示例中的公式解析为对i和j进行求和,然后将结果存储在输出张量的k和l位置上。 总之,`torch.einsum`是一个非常灵活和强大的函数,可以用于执行各种张量运算。它可以根据指定的公式对输入张量进行操作,并生成输出张量。 #### 引用[.reference_title] - *1* *2* [【Pytorch写代码技巧--EinsumEinsum详解+常用写法](https://blog.csdn.net/ccaoshangfei/article/details/126995397)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Pytorch中, torch.einsum详解。](https://blog.csdn.net/a2806005024/article/details/96462827)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值