Pytorch:torch.mul() 、torch.mm()、torch.bmm()、torch.matmul()

1 torch.mul() 

  • 用标量值value乘以输入input的每个元素,并返回一个新的结果张量。 \( out=tensor ∗ value \)。如果输入是FloatTensor or DoubleTensor类型,则value 必须为实数,否则须为整数。

torch.mul(input, value, out=None)
参数描述
input (Tensor)输入张量
value (Number)  乘到每个元素的数
out (Tensor)可选,输出张量

栗子:

>>> a = torch.randn(3)
>>> a

-0.9374
-0.5254
-0.6069
[torch.FloatTensor of size 3]

>>> torch.mul(a, 100)

-93.7411
-52.5374
-60.6908
[torch.FloatTensor of size 3]
  • 两个张量input,other按元素进行相乘,并返回到输出张量。即计算\( out_i=input_i ∗ other_i \)。两计算张量形状不须匹配,但总元素数须一致。

torch.mul(input, other, out=None)
参数描述
input (Tensor)第一个相乘张量
other (Tensor) 第二个相乘张量
out (Tensor)可选,输出张量

栗子:

>>> a = torch.randn(4,4)
>>> a

-0.7280  0.0598 -1.4327 -0.5825
-0.1427 -0.0690  0.0821 -0.3270
-0.9241  0.5110  0.4070 -1.1188
-0.8308  0.7426 -0.6240 -1.1582
[torch.FloatTensor of size 4x4]

>>> b = torch.randn(2, 8)
>>> b

 0.0430 -1.0775  0.6015  1.1647 -0.6549  0.0308 -0.1670  1.0742
-1.2593  0.0292 -0.0849  0.4530  1.2404 -0.4659 -0.1840  0.5974
[torch.FloatTensor of size 2x8]

>>> torch.mul(a, b)

-0.0313 -0.0645 -0.8618 -0.6784
 0.0934 -0.0021 -0.0137 -0.3513
 1.1638  0.0149 -0.0346 -0.5068
-1.0304 -0.3460  0.1148 -0.6919
[torch.FloatTensor of size 4x4]

2 torch.mm()

处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul()。torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

torch.mm(input, mat2, out=None)
  • 栗子
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)

tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

3 torch.bmm()

torch.bmm(input, mat2, out=None)

看函数名就知道,在torch.mm的基础上加了个batch计算,不能广播。

4 torch.matmul()

torch.matmul(input, other, out=None)

功能
适用性最多的,能处理batch、广播的矩阵:

  1. 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
  2. 如果第一个是二维,第二个是一维,就是矩阵乘向量
  3. 带有batch的情况,可保留batch计算
  4. 维度不同时,可先广播,再batch计算

栗子:

  • vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
  • matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
  • batched matrix x broadcasted vecto
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
  • batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

总结:

对位相乘用torch.mul,二维矩阵乘法用torch.mm,batch二维矩阵用torch.bmm,batch、广播用torch.matmul

参考:

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值