Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果

本文详细介绍了在PyTorch中使用torch.mm(),torch.bmm()和torch.matmul()进行矩阵乘法的方法,包括函数定义、参数、示例以及它们在处理不同维度张量时的行为和广播机制。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 torch.mm() 函数

全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维

1.1 torch.mm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的矩阵
** mat2
* (Tensor) – – 第二个要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

1.2 torch.bmm() 官方示例

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)

tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

2 torch.bmm() 函数

全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;

2.1 torch.bmm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一批要相乘的矩阵
** mat2
* (Tensor) – – 第二批要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

2.2 torch.bmm() 官方示例

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()

torch.Size([10, 3, 5])

3 torch.matmul() 函数

可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数

3.1 torch.matmul() 函数定义及参数

torch.matmul(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的张量
** mat2
* (Tensor) – – 第二个要相乘的张量
支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

3.2 torch.matmul() 规则约定

(1)若两个都是1D(向量)的,则返回两个向量的点积;

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;

(4)若input是2D,other是1D,则返回两者的点积结果;

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

  • (a)若input是1D,other是大于2D的,则类似于规则(3);
  • (b)若other是1D,input是大于2D的,则类似于规则(4);
  • (c)若input和other都是3D的,则与torch.bmm()函数功能一样;
  • (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。

3.3 torch.matmul() 官方示例

# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()

torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()

torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()

torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()

torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()

torch.Size([10, 3, 5])

3.4 高维数据实例解释

直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算
在这里插入图片描述

代码如下:

import torch
import numpy as np

np.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(a)
print("=============================================")
print(b)
print("=============================================")
print(c.size())
print("=============================================")
print(c)

运行结果为:

tensor([[[[1, 0, 1, 0],
          [1, 1, 0, 1],
          [0, 0, 0, 0]],

         [[1, 1, 1, 1],
          [1, 1, 0, 0],
          [0, 1, 0, 1]]],


        [[[0, 0, 0, 1],
          [0, 0, 0, 1],
          [0, 1, 0, 0]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [0, 0, 0, 0]]]], dtype=torch.int32)
=============================================
tensor([[[[0, 1, 0],
          [1, 1, 0],
          [0, 0, 0],
          [1, 1, 0]]],


        [[[0, 1, 0],
          [1, 1, 1],
          [1, 1, 1],
          [1, 0, 1]]]], dtype=torch.int32)
=============================================
torch.Size([2, 2, 3, 3])
=============================================
tensor([[[[0, 1, 0],
          [2, 3, 0],
          [0, 0, 0]],

         [[2, 3, 0],
          [1, 2, 0],
          [2, 2, 0]]],


        [[[1, 0, 1],
          [1, 0, 1],
          [1, 1, 1]],

         [[3, 3, 3],
          [3, 3, 3],
          [0, 0, 0]]]], dtype=torch.int32)

参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 官方文档查询地址
https://pytorch.org/docs/stable/index.html
参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
https://blog.csdn.net/irober/article/details/113686080

<think>好的,我需要帮助用户了解PyTorch中的张量乘法方法。首先,我得回忆一下PyTorch提供的几种乘法方式,比如torch.mmtorch.matmultorch.dot、torch.mul以及运算符@和*的区别。 用户可能混淆这些方法的应用场景,所以我要分清楚每种方法的用途。例如,torch.dot只适用于一维向量,而torch.mm用于二维矩阵乘法。还有元素级乘法torch.mul或*运算符。 接下来,我需要检查引用内容。引用1提到不同乘法方法的应用,引用2和4给出了示例代码,特别是torch.dot和哈达玛积的例子。引用3介绍了张量的基本概念,可能作为背景知识。 需要确保回答结构清晰,分点说明每种方法,并给出示例代码。同时,要符合用户要求的格式:行内公式用$...$,独立公式用$$...$$,正确的中文回答,生成相关问题,并在引用处添加标识。 还要注意代码块的正确格式,比如使用三重反引号标注Python代码。最后,生成的相关问题需要与PyTorch张量乘法相关,帮助用户深入理解。 现在,组织内容: 1. 点积(torch.dot):仅一维向量,结果标量。 2. 矩阵乘法torch.mm):二维矩阵,严格尺寸匹配。 3. 批量矩阵乘法torch.bmm):三维张量,批量处理。 4. 张量乘法torch.matmul):更灵活,支持广播。 5. 元素级乘法torch.mul或*):哈达玛积。 6. 运算符@和*的区别。 每个部分给出代码示例,并引用对应的引用内容。最后生成相关问题,如不同乘法应用场景、处理不同维度的张量、性能差异等。 检查是否有遗漏,比如是否提到广播机制,或者在torch.matmul中的不同维度处理。确保每个示例正确无误,比如使用正确的函数和参数。 最后,确保所有引用标识正确添加,比如引用[^1]和[^4]在相关段落末尾。同时,避免使用Markdown格式,保持自然的中文表达。</think>PyTorch中的张量乘法多种形式,每种方法对应不同的数学运算场景。以下是主要方法及示例说明: ### 1. **点积(torch.dot)** - **用途**:仅适用于一维向量,计算对应元素乘积之和。 - **公式**:$ \mathbf{a} \cdot \mathbf{b} = \sum_{i=1}^n a_i b_i $ - **示例**: ```python vector1 = torch.tensor([1, 2, 3]) vector2 = torch.tensor([4, 5, 6]) print(torch.dot(vector1, vector2)) # 输出: 1*4 + 2*5 + 3*6 = 32 [^2] ``` ### 2. **矩阵乘法torch.mm)** - **用途**:严格二维矩阵乘法,要求尺寸匹配。 - **公式**:若矩阵$A$为$m \times n$,矩阵$B$为$n \times p$,则结果矩阵$C$为$m \times p$,其中$C_{ij} = \sum_{k=1}^n A_{ik} B_{kj}$。 - **示例**: ```python matrix1 = torch.tensor([[1, 2], [3, 4]]) matrix2 = torch.tensor([[5, 6], [7, 8]]) print(torch.mm(matrix1, matrix2)) # 输出: [[19, 22], [43, 50]] ``` ### 3. **批量矩阵乘法torch.bmm)** - **用途**:三维张量的批量乘法,适用于多个二维矩阵的并行计算。 - **示例**: ```python batch1 = torch.randn(3, 2, 4) # 3个2x4矩阵 batch2 = torch.randn(3, 4, 3) # 3个4x3矩阵 result = torch.bmm(batch1, batch2) # 输出形状为(3, 2, 3) ``` ### 4. **通用张量乘法torch.matmul)** - **用途**:支持多维张量,自动处理广播机制,适用性最广。 - **示例**: ```python # 二维矩阵乘法 print(torch.matmul(matrix1, matrix2)) # 同torch.mm # 高维张量广播乘法 tensor1 = torch.randn(2, 3, 4) tensor2 = torch.randn(4, 5) print(torch.matmul(tensor1, tensor2).shape) # 输出: (2, 3, 5) ``` ### 5. **元素级乘法torch.mul或*)** - **用途**:哈达玛积(Hadamard product),对应元素相乘,要求张量形状一致。 - **公式**:$(A \odot B)_{ij} = A_{ij} B_{ij}$ - **示例**: ```python data1 = torch.tensor([[1, 2], [3, 4]]) data2 = torch.tensor([[5, 6], [7, 8]]) print(data1 * data2) # 输出: [[5, 12], [21, 32]] [^4] ``` ### 6. **运算符@与*的区别** - `@`等价于`torch.matmul`,用于矩阵乘法。 - `*`等价于`torch.mul`,用于元素级乘法。 ```python print(matrix1 @ matrix2) # 矩阵乘法 print(matrix1 * matrix2) # 元素乘法 ``` ###
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值