torch.mm()和torch.matmul()

介绍

torch.mm()torch.matmul()都是PyTorch库中用来进行矩阵乘法的函数,但它们在处理输入时有一些不同。

torch.mm()

torch.mm():这个函数仅接受二维矩阵作为输入,并进行矩阵乘法。如果输入的张量不是二维的,它会抛出一个错误。

torch.matmul()

torch.matmul():这个函数可以接受高于二维的张量,并进行适当的广播和矩阵乘法。对于二维矩阵,torch.matmul()torch.mm()的行为是相同的。

例子

二维矩阵

import torch
a = torch.randn(3, 4)
b = torch.randn(4, 5)
c = torch.mm(a, b)  # 使用torch.mm()
d = torch.matmul(a, b)  # 使用torch.matmul()
print(c.shape)  # torch.Size([3, 5])
print(d.shape)  # torch.Size([3, 5])

高维矩阵

e = torch.randn(2, 3, 4)
f = torch.randn(2, 4, 5)
# torch.mm仅接受二维矩阵作为输入
g = torch.mm(e, f)  #报错:RuntimeError: self must be a matrix
g = torch.matmul(e, f)
print(g.shape) # torch.Size([2, 3, 5])

高维广播

除了最后两个维度外,其他的维度相等

x = torch.randn(2, 4, 3, 4)
v = torch.randn(2, 4, 4, 5)
g = torch.matmul(x, v)
print(g.shape)  # torch.Size([2, 4, 3, 5])

除了最后两个维度外,其他的不相等的维度,其中一个是1,会自动广播

x1 = torch.randn(2, 1, 3, 4)
v1 = torch.randn(1, 4, 4, 5)
g1 = torch.matmul(x1, v1)
print(g1.shape)  # torch.Size([2, 4, 3, 5])

除了最后两个维度外,其他的不相等的维度,且不都是1,报错

x2 = torch.randn(3, 4, 3, 4)
v2 = torch.randn(2, 1, 4, 5)
g2 = torch.matmul(x2, v2)
# 张量a的第0维的大小是3,而张量b的第0维的大小是2,这两者不匹配,都不是1,不能广播
print(g2.shape)  # The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值