torch.mul、torch.mm、torch.bmm、torch.matmul的区别

torch.mul

torch.mul(input, other, out=None)

功能

对位相乘,可以广播

该函数能处理两种情况

  1. input是矩阵/向量,other是标量
    这个时候是就是input的所有元素乘上other
  2. input是矩阵/向量,other是矩阵/向量
    这时 o u t i = i n p u t i × o t h e r i out_i = input_i \times other_i outi=inputi×otheri,对位相乘,如果两个都是向量,则可以广播的

例子

  1. input和other的size相同的对位相乘

    a: tensor([[ 1.8351,  2.1536],
        [-0.8320, -1.4578]])
    b: tensor([[2.9355, 0.3450],
        [0.5708, 1.9957]])
    c = torch.mul(a,b):
     tensor([[ 5.3869,  0.7429],
        [-0.4749, -2.9093]])
    
  2. 两个向量的广播

    a: tensor([[ 1.8351,  2.1536],
            [-0.8320, -1.4578]])
    b: tensor([[2.9355, 0.3450],
            [0.5708, 1.9957]])
    c = torch.mul(a,b):
     tensor([[ 5.3869,  0.7429],
            [-0.4749, -2.9093]])
    

torch.mm

torch.mm(input, mat2, out=None)

解决的问题

处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul

例子

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

torch.bmm

torch.bmm(input, mat2, out=None)

看函数名就知道,在torch.mm的基础上加了个batch计算,不能广播


torch.matmul

torch.matmul(input, other, out=None)

功能
适用性最多的,能处理batch、广播的矩阵:

  1. 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
  2. 如果第一个是二维,第二个是一维,就是矩阵乘向量
  3. 带有batch的情况,可保留batch计算
  4. 维度不同时,可先广播,再batch计算

例子

  1. vector x vector

    tensor1 = torch.randn(3)
    tensor2 = torch.randn(3)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([])
    
  2. matrix x vector

    tensor1 = torch.randn(3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([3])
    
  3. batched matrix x broadcasted vecto

    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([10, 3])
    
  4. batched matrix x batched matrix

    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(10, 4, 5)
    torch.matmul(tensor1, tensor2).size()
    torch.Size([10, 3, 5])
    

总结

对位相乘用torch.mul,二维矩阵乘法用torch.mm,batch二维矩阵用torch.bmm,batch、广播用torch.matmul

### PyTorch 中张量矩阵乘法的用法及维度匹配规则 在 PyTorch 中,矩阵乘法可以通过 `torch.mm` `torch.matmul` 来实现。对于批量矩阵乘法,则可以使用 `torch.bmm` 或者更通用的 `torch.matmul` 函数。 #### 基本矩阵乘法规则 当执行标准矩阵乘法时,假设有一个形状为 `(m, n)` 的矩阵 A 一个形状为 `(n, p)` 的矩阵 B,那么它们相乘的结果 C 将是一个形状为 `(m, p)` 的矩阵[^1]: ```python import torch tensor1 = torch.randn(3, 4) # 形状为 (3, 4) tensor2 = torch.randn(4, 5) # 形状为 (4, 5) result = torch.mm(tensor1, tensor2) # 结果形状应为 (3, 5) print(result.shape) ``` #### 批量矩阵乘法规则 如果要处理多个矩阵的同时乘法(即批量矩阵),比如有两组大小相同的矩阵集合 X Y,每组都包含 b 个矩阵,其中每个矩阵分别为 m×n n×p 大小,那么最终得到的是由 b 个 m×p 矩阵组成的输出 Z: ```python batch_tensor1 = torch.randn(2, 3, 4) # 形状为 (2, 3, 4),表示有两个 3x4 的矩阵 batch_tensor2 = torch.randn(2, 4, 5) # 形状为 (2, 4, 5),表示有两个 4x5 的矩阵 batch_result = torch.bmm(batch_tensor1, batch_tensor2) # 输出形状应为 (2, 3, 5) print(batch_result.shape) ``` #### 维度广播机制下的矩阵乘法 除了上述两种情况外,在某些情况下即使输入张量不完全满足严格的尺寸要求也可以通过自动扩展来进行运算。例如,当其中一个操作数是二维而另一个是一维时,一维的操作数会被视为具有额外的一批或列/行来适应另一方;或者当两个三维张量的最后一维倒数第二维分别相同的时候也能正常工作。 ```python matrix = torch.randn(3, 4) # 形状为 (3, 4) vector = torch.randn(4) # 形状为 (4,) broadcasted_mul = matrix @ vector # 这里会把 vector 当作 (4, 1) 来对待,结果形状为 (3,) print(broadcasted_mul.shape) ``` 需要注意的是,为了使这些函数能够成功运行而不抛出错误,参与运算的对象之间的相应轴长度必须严格一致或者是其中之一等于1以便于广播。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值