pytorch矩阵中每行(列)乘上不同元素以及矩阵相乘

torch.mul和*等价(attetion中可以用到)
每行乘上不同元素

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
        [2.],
        [3.]])
>>> a * b
tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])

每列乘上不同元素

>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4]).reshape((1,4))
>>> b
tensor([[1., 2., 3., 4.]])
>>> a*b
tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])
>>> torch.mul(a,b)
tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])

带batch(mul和*会自动broadcaset到所以batch)

>>> a=torch.ones(2,3,4)
>>> a
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
>>> b = torch.Tensor([1,2,3,4]).reshape((1,4))
>>> b
tensor([[1., 2., 3., 4.]])
>>> a*b
tensor([[[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]],

        [[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]]])
>>> torch.mul(a,b)
tensor([[[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]],

        [[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]]])

此外还有针对矩阵的乘法如:torch.dot() torch.mm() torch.bmm() torch.dot()是针对一维的向量进行点积。

In [252]: a=torch.randn(2,3)

In [253]: b=torch.randn(3,2)

In [254]: torch.dot(a,b)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-254-4939a8ae602a> in <module>()
----> 1 torch.dot(a,b)

RuntimeError: dot: Expected 1-D argument self, but got 2-D

In [255]: a=torch.randn(3)

In [256]: b=torch.randn(3)

In [257]: torch.dot(a,b)
Out[257]: tensor(0.0967)

torch.mm是针对矩阵的点积(只针对2维)

In [258]: a=torch.randn(2,3)

In [259]: b=torch.randn(3,2)

In [260]: torch.mm(a,b)
Out[260]: 
tensor([[-1.2849,  0.1272],
        [ 0.0600, -0.3183]])

In [261]: torch.mm(a,b).size()
Out[261]: torch.Size([2, 2])

torch.bmm()是针对一个batch的二维矩阵进行点积(假设batch_size=2)

In [262]: a=torch.randn(2,2,3)

In [263]: b=torch.randn(2,3,2)

In [264]: torch.mm(a,b).size()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-264-6bdb1a27d804> in <module>()
----> 1 torch.mm(a,b).size()

RuntimeError: matrices expected, got 3D, 3D tensors at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2065

In [265]: torch.bmm(a,b).size()
Out[265]: torch.Size([2, 2, 2])

那么如果维度大于3,我们要针对数据集中单个的矩阵进行点积怎么办呢?比如在多头attention中,batch=2,head=2
torch.matmul只针对最后俩个维度进行点积。

In [266]: b=torch.randn(2,2,3,2)

In [267]: a=torch.randn(2,2,2,3)

In [268]: torch.bmm(a,b).size()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-268-3084f0e99edf> in <module>()
----> 1 torch.bmm(a,b).size()

RuntimeError: invalid argument 1: expected 3D tensor, got 4D at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2304

In [269]: torch.matmul(a,b).size()
Out[269]: torch.Size([2, 2, 2, 2])

In [270]: a=torch.randn(2,2,3)

In [271]: b=torch.randn(2,3,2)

In [272]: torch.matmul(a,b).size()
Out[272]: torch.Size([2, 2, 2])

可以matmul同样可以实现bmm

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值