torch.bmm()验证

官网的说明

torch.bmm(input, mat2, out=None) → Tensor

bmm的输入必须是3维的。其他维度会出错:

import torch
a = torch.Tensor(4,2,2,3)
b = torch.Tensor(4,2,3,5)
c = torch.bmm(a,b)

Traceback (most recent call last):
  File "/Users/XXX/Desktop/MyCode/xxx.py", line 1436, in <module>
    c = torch.bmm(a,b)
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)

下面我们演示一下bmm的使用:

import torch

a = torch.stack( [torch.ones(3,4)*torch.tensor(i+1) for i in range(5)], dim=0)
b = a.transpose(1,2)
#a.shape: (5,3,4)
#b.shape: (5,4,3)

c = torch.bmm(a,b)
#c.shape: (5,3,3)
print(c)

tensor([[[  4.,   4.,   4.],
         [  4.,   4.,   4.],
         [  4.,   4.,   4.]],

        [[ 16.,  16.,  16.],
         [ 16.,  16.,  16.],
         [ 16.,  16.,  16.]],

        [[ 36.,  36.,  36.],
         [ 36.,  36.,  36.],
         [ 36.,  36.,  36.]],

        [[ 64.,  64.,  64.],
         [ 64.,  64.,  64.],
         [ 64.,  64.,  64.]],

        [[100., 100., 100.],
         [100., 100., 100.],
         [100., 100., 100.]]])

代码中我们设置了5个3*4 的tensor stack在一起,其转置相应的是 4*3。
我们的a中的每个都是一个全1到全5的矩阵。我们知道:
I ∈ R m ∗ n I \in R^{m*n} IRmn,
I ∗ I T = n ∗ I ′ , I ′ ∈ R m ∗ m I*I^{T}=n*I',I'\in R^{m*m} IIT=nI,IRmm ,
a I ∗ a I T = n ∗ a 2 ∗ I ′ aI*aI^T=n*a^2*I' aIaIT=na2I
上述的结果正好是 a = 1 , 2 , 3 , 4 , 5 a=1,2,3,4,5 a=1,2,3,4,5的情况,因此,bmm的作用是batch号相同的两个矩阵之间的矩阵乘,不同batch号之间的矩阵无关联!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值