pytorch中的矩阵乘法操作:torch.mm(), torch.bmm(), torch.mul()和*, torch.dot(), torch.mv(), @, torch.matmul()

12 篇文章 4 订阅
本文介绍了PyTorch中处理张量乘法的不同方法,包括torch.mm()的矩阵乘法,torch.bmm()的批量矩阵乘法,torch.mul()和*的元素乘法,torch.dot()的向量点积,torch.mv()的矩阵向量乘法,以及@和torch.matmul()的灵活矩阵运算,涵盖了从一维到高维的多种情况,并详细解释了广播法则的应用。
摘要由CSDN通过智能技术生成

😄 无聊整理下torch里的张量的各种乘法相关操作。

0、简单提一下广播法则的定义:

  • 1、让所有输入张量都向其中shape最长的矩阵看齐,shape不足的部分在前面加1补齐。
  • 2、两个张量的维度要么在某一个维度一致,若不一致其中一个维度为1也可广播。否则不能广播。【如两个维度:(4, 1, 4)和(2, 1)可以广播,因为他们不相等的维度其中一个为1就可以广播了。】

1、torch.mm()

- 只适合于二维张量的矩阵乘法。
- m x n, n x p -> m x p

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)
out = torch.mm(mat1, mat2)
out.shape
# torch.Size([2, 4])

2、torch.bmm()

- 只适合于三维张量的矩阵乘法,与torch.mm类似,但多了一个batch_size维度。
- b x m x n, b x n x p -> b x m x p

mat1 = torch.randn(8, 2, 3)
mat2 = torch.randn(8, 3, 4)
out = torch.bmm(mat1, mat2)
out.shape
# torch.Size([8, 2, 4])

3、torch.mul()和*

  • - ⭐ torch.mul()和*等价。
    - 张量对应位置元素相乘。
    - 将输入张量input的每个元素与另一个向量or标量other相乘,返回一个新的张量out,两者维度需满足广播规则
# 方式1:张量 和 标量相乘
input = torch.randn(3)
other = 100
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([3])

# 方式2:张量 和 张量(需满足广播规则)
input = torch.randn(4, 1, 4)
other = torch.randn(2, 1)
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([4, 2, 4])

# 方式3:元素对应项相乘
input = torch.randn(3, 2)
other = torch.randn(3, 2)
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([3, 2])

4、torch.dot()

向量点积:两向量对应位置相乘然后全部相加。只能支持两个一维向量。

5、torch.mv()

矩阵和向量的乘法

  • 第一个参数只能是二维的,第二个参数是一维的,则在其维数末尾追加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(2,6)和(6)运算过程:(2,6)和(6,1) -> (2,1) -> (2)
mat1 = torch.randn(6,8)
mat2 = torch.randn(8)
out = torch.mv(mat1, mat2)
out.shape
# torch.Size([6])

6、@

torch中的@操作是可以实现前面的某几个函数,是一种强大的操作。

  • 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
  • 若mat1是二维向量,mat2是一维向量,那么对应操作就是torch.mv()
  • 若mat1和mat2都是两个二维向量,那么对应操作就是torch.mm()

7、torch.matmul()

torch.matmul()与@操作类似,但是torch.matmul()不止局限于一维和二维,可以进行高维张量的乘法。两个张量的矩阵乘积。其行为取决于张量的维数如下:

  • 1、如果两个张量都是一维的,则返回点积(标量)。

  • 2、如果两个参数都是二维的,则返回矩阵-矩阵乘积。

  • 3、如果第一个参数是二维的,第二个参数是一维的,则在其维数末尾追加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(2,6)和(6)运算过程:(2,6)和(6,1) -> (2,1) -> (2)

  • 4、如果第一个参数是一维的(则在其维数前加一个1,),第二个参数是二维的,则返回矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(6)和(6,2)运算过程:(1,6)和(6,2) -> (1,2) -> (2)

  • 5、对3和4的总结。如果两个参数至少是一个参数是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)

  • 两个参数都是N维(>2),只有非矩阵的维度才是可以广播的,最后两维需满足矩阵乘法即m x n, n x p -> m x p。如bx1xnxm, kxmxp -> jxkxnxp

  >>> # vector x vector
  >>> tensor1 = torch.randn(3)
  >>> tensor2 = torch.randn(3)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([])
  >>> # matrix x vector
  >>> tensor1 = torch.randn(3, 4)
  >>> tensor2 = torch.randn(4)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([3])
  >>> # batched matrix x broadcasted vector
  >>> tensor1 = torch.randn(10, 3, 4)
  >>> tensor2 = torch.randn(4)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([10, 3])
  >>> # batched matrix x batched matrix
  >>> tensor1 = torch.randn(10, 3, 4)
  >>> tensor2 = torch.randn(10, 4, 5)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([10, 3, 5])
  >>> # batched matrix x broadcasted matrix
  >>> tensor1 = torch.randn(10, 3, 4)
  >>> tensor2 = torch.randn(4, 5)
  >>> torch.matmul(tensor1, tensor2).size()
  torch.Size([10, 3, 5])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

#苦行僧

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值