Pytorch常用乘法函数总结:torch.mul()、*、torch.mm()、torch.bmm()、torch.mv()、torch.dot()、@、torch.matmul()

上一篇博客总结了numpy中常用的乘法函数:

numpy常用乘法函数总结:np.dot()、np.multiply()、*、np.matmul()、@、np.prod()、np.outer()-CSDN博客

主要是 np.dot()、np.multiply()、*、np.matmul()、@ 五种,其中 np.matmul() 和 @ 完全等价,np.multiply() 和 * 在输入数据类型为 np.array 时也完全等价

本文总结Pytorch常用的乘法函数,TensorFlow常用乘法函数放在下一篇文章里

目录

torch.mul() 和 *【等价,element-wise乘,可广播】

广播操作

torch.mm() 或 torch.bmm() 【矩阵乘法,前二维后三维,均不可广播】

torch.mv()【矩阵-向量乘法,不可广播】

torch.dot()【仅支持两个一维向量点积】

@【等价于 torch.dot() + torch.mv() + torch.mm()】

torch.matmul() 【矩阵乘法,可高维,可广播】

参考


torch.mul() 和 *【等价,element-wise乘,可广播】

torch.mul(x, y) 等价于 x*y 

  • 矩阵 * 标量:矩阵中的每个元素都乘以标量
  • 矩阵 * 行向量:要求矩阵的列数 = 行向量的列数
  • 矩阵 * 列向量:要求矩阵的行数 = 列向量的行数
  • 矩阵 * 矩阵:要求两个矩阵的维度完全相同,不同的话看是否能广播成相同shape
vec1 = torch.arange(4)
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(4,3)
mat2 = torch.arange(12).reshape(3,4)

print(vec1 * vec2)
print(mat2 * vec1)
print(mat1 * mat1)

Output:
tensor([0, 3, 4, 3])
tensor([[ 0,  1,  4,  9],
        [ 0,  5, 12, 21],
        [ 0,  9, 20, 33]])
tensor([[  0,   1,   4],
        [  9,  16,  25],
        [ 36,  49,  64],
        [ 81, 100, 121]])

一句话总结:如果两者shape相同,直接element-wise相乘;如果两者shape不同,本质上是先把另一个扩展/复制成相同的shape(即广播操作),再对应元素相乘 

广播操作

(1)在一定的规则下允许高维Tensor和低维Tensor之间的运算。这里举一个例子:a是二维Tensor,b是三维Tensor,但是a的维度与b的后两位相同,那么a和b仍然可以做 * 操作,结果是一个和b维度一样的三维Tensor。可以理解为沿着b的第0维做二维Tensor点积,或运算前将a沿着b的第0维进行了expand操作

import torch

a = torch.tensor([[1, 2], [2, 3]])   # torch.Size([2, 2])
b = torch.tensor([[[1, 2], [2, 3]], [[-1, -2], [-2, -3]]])   # torch.Size([2, 2, 2])

print(a*b)    # torch.Size([2, 2, 2])   b*a 结果相同
'''
tensor([[[ 1,  4],
         [ 4,  9]],

        [[-1, -4],
         [-4, -9]]])
'''

运算过程可以理解为:

a = a.expand(b.size())
'''
tensor([[[1, 2],
         [2, 3]],

        [[1, 2],
         [2, 3]]])
'''

a*b

(2)再举个例子:两个Tensor都是3维,但shape不一致,这种情况看是否能够广播成一致的shape

import torch

node = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])   # shape=(2,5)

node_0 = node.unsqueeze(-1)   # torch.Size([2, 5, 1])
'''
tensor([[[1],
         [1],
         [1],
         [0],
         [0]],

        [[1],
         [1],
         [1],
         [1],
### PyTorch 中张量矩阵乘法的用法及维度匹配规则 在 PyTorch 中,矩阵乘法可以通过 `torch.mm` 和 `torch.matmul` 来实现。对于批量矩阵乘法,则可以使用 `torch.bmm` 或者更通用的 `torch.matmul` 函数。 #### 基本矩阵乘法规则 当执行标准矩阵乘法时,假设有一个形状为 `(m, n)` 的矩阵 A 和一个形状为 `(n, p)` 的矩阵 B,那么它们相乘的结果 C 将是一个形状为 `(m, p)` 的矩阵[^1]: ```python import torch tensor1 = torch.randn(3, 4) # 形状为 (3, 4) tensor2 = torch.randn(4, 5) # 形状为 (4, 5) result = torch.mm(tensor1, tensor2) # 结果形状应为 (3, 5) print(result.shape) ``` #### 批量矩阵乘法规则 如果要处理多个矩阵的同时乘法(即批量矩阵),比如有两组大小相同的矩阵集合 X 和 Y,每组都包含 b 个矩阵,其中每个矩阵分别为 m×n 和 n×p 大小,那么最终得到的是由 b 个 m×p 矩阵组成的输出 Z: ```python batch_tensor1 = torch.randn(2, 3, 4) # 形状为 (2, 3, 4),表示有两个 3x4 的矩阵 batch_tensor2 = torch.randn(2, 4, 5) # 形状为 (2, 4, 5),表示有两个 4x5 的矩阵 batch_result = torch.bmm(batch_tensor1, batch_tensor2) # 输出形状应为 (2, 3, 5) print(batch_result.shape) ``` #### 维度广播机制下的矩阵乘法 除了上述两种情况外,在某些情况下即使输入张量不完全满足严格的尺寸要求也可以通过自动扩展来进行运算。例如,当其中一个操作数是二维而另一个是一维时,一维的操作数会被视为具有额外的一批或列/行来适应另一方;或者当两个三维张量的最后一维和倒数第二维分别相同的时候也能正常工作。 ```python matrix = torch.randn(3, 4) # 形状为 (3, 4) vector = torch.randn(4) # 形状为 (4,) broadcasted_mul = matrix @ vector # 这里会把 vector 当作 (4, 1) 来对待,结果形状为 (3,) print(broadcasted_mul.shape) ``` 需要注意的是,为了使这些函数能够成功运行而不抛出错误,参与运算的对象之间的相应轴长度必须严格一致或者是其中之一等于1以便于广播。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cheer-ego

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值