pytorch矩阵乘法mm,bmm

矩阵维度

首先需要确认多维矩阵每个维度的对应含义。

a = torch.tensor([[[3.], [1.]], [[4.], [0.]], [[2.], [1.]]])

这是一个三维矩阵,他的size是:3,2,1
很容易理解和记忆,最外层[]内的元素数即size的第一个参数,第二层对应第二个参数,以此类推。
那么其中元素对应关系就有了:
a 000 = 3 a_{000}=3 a000=3 a 210 = 1 a_{210}=1 a210=1
整个矩阵就有:
[ [ 3 1 ] [ 4 0 ] [ 2 1 ] ] \begin{bmatrix} \begin{bmatrix}3 \\ 1\end{bmatrix} \begin{bmatrix}4 \\ 0\end{bmatrix} \begin{bmatrix}2 \\ 1\end{bmatrix} \end{bmatrix} [[31][40][21]]
正常来讲,最内侧的两维可以用矩阵行列表示,外面就只能通过嵌套矩阵表示了。
最内侧为行,次内侧为列

矩阵乘法

pytorch的矩阵乘法中,主要内容为2维×2维。也就是用torch.mm这个函数。

torch.mm

直接上例子:

import torch
b = torch.tensor([[2., 3.]])
c = torch.tensor([[3.], [1.]])
print(b)
print(c)
out1 = torch.mm(b, c)
print(out1)
# ### output
tensor([[2., 3.]])
tensor([[3.],
        [1.]])
tensor([[9.]])

[ 2   3 ] × [ 3 1 ] = [ 9 ] \begin{bmatrix}2 \ 3 \end{bmatrix} ×\begin{bmatrix}3 \\ 1\end{bmatrix}=[9] [2 3]×[31]=[9]
先理清矩阵维度,这个就很容易理解了。

torch.bmm

矩阵批处理乘法。怎么命名不重要,功能就是两批2维矩阵对应相乘。

a = torch.tensor([[[2., 3.], [1., 2.]], [[3., 4.], [0., 5.]]])
w = torch.tensor([[[3.], [1.]], [[2.], [4.]]])
print(a)
print(w)
out = torch.bmm(a, w)
print(out)

tensor([[[2., 3.],
         [1., 2.]],

        [[3., 4.],
         [0., 5.]]])
tensor([[[3.],
         [1.]],

        [[2.],
         [4.]]])
tensor([[[ 9.],
         [ 5.]],

        [[22.],
         [20.]]])

torch.matmul

这个我只知道在二维时与mm结果一样,高维的更复杂,但是没有深入研究。
在这里插入图片描述

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值