torch.mm() torch.sparse.mm() torch.bmm() torch.mul() torch.matmul()的区别

torch.mm()

二维矩阵的乘法,假设输入矩阵mat1维度是 ( m × n ) (m×n) (m×n)​​​,矩阵mat2维度是 ( n × p ) (n×p) (n×p)​​​​,则输出维度为 ( m × p ) (m×p) (m×p)​,只能是二维的​​

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
out = torch.mm(mat1, mat2)
'''
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])
'''

torch.sparse.mm()

a是稀疏矩阵,b是稀疏矩阵或者密集矩阵,sparse.mm的作用和torch.mm一样,都是做矩阵乘法计算

a = torch.randn(2, 3).to_sparse().requires_grad_(True)
b = torch.randn(3, 2, requires_grad=True)
y = torch.sparse.mm(a, b)

'''
a:
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
                       [0, 1, 2, 0, 1, 2]]),
       values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
       size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)

b:
tensor([[-0.6479,  0.7874],
        [-1.2056,  0.5641],
        [-1.1716, -0.9923]], requires_grad=True)
        
y:
tensor([[-0.3323,  1.8723],
        [-1.8951,  0.7904]], grad_fn=<SparseAddmmBackward>)
'''

torch.bmm()

torch.mm类似,但多了一个batch_size维度,输入矩阵张量mat1维度是 ( b × m × n ) (b×m×n) (b×m×n),矩阵张量mat2维度是 ( b × n × p ) (b×n×p) (b×n×p),则输出维度为 ( b × m × p ) (b×m×p) (b×m×p)​​

mat1 = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(mat1, mat2)
print(res.size())
# torch.Size([10, 3, 5])

torch.mul()

将输入张量input的每个元素与另一个标量other相乘,返回一个新的张量out,两者维度需满足广播规则

# 方式1:张量 和 标量相乘
a = torch.randn(3)
torch.mul(a, 100)
'''
a: tensor([ 0.2015, -0.4255,  2.6087])
tensor([  20.1494,  -42.5491,  260.8663])
'''

# 方式2:张量 和 张量(需满足广播规则)
a = torch.randn(4, 1)
b = torch.randn(1, 4)
c = torch.mul(a,b)
'''
c:
tensor([[-0.1183, -0.4246, -0.0512,  0.1757],
        [-0.4215, -1.5121, -0.1823,  0.6257],
        [-0.0358, -0.1284, -0.0155,  0.0531],
        [ 0.1649,  0.5917,  0.0713, -0.2448]])
'''

# 方式3:元素对应项相乘
a = torch.randn(3, 2)
b = torch.randn(3, 2)
c = torch.mul(a,b)
'''
C:
tensor([[-1.9259, -0.0116],
        [-1.8523, -0.0392],
        [-0.4881, -0.4235]])
'''

torch.matmul()

两个张量的矩阵乘积。其行为取决于张量的维数如下:

  • 如果两个张量都是一维的,则返回点积(标量)。
  • 如果两个参数都是二维的,则返回矩阵-矩阵乘积。
  • 如果第一个参数是一维的,第二个参数是二维的,则在其维数前加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。
  • 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量乘积。
  • 如果两个参数至少是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)
# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])

# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])

# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])

# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

总结

  • 二维矩阵乘积用torch.mm()torch.sparse.mm()
  • 多批次的二维矩阵之间的乘积用torch.bmm()
  • 标量乘积或对应项乘积用torch.mul()
  • 批次或广播进行乘积用torch.matmul()
  • 11
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小风_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值