一个关于pytorch的tensor点乘的小问题

事情的缘由是,坐旁边的学姐有一段代码将两个二维数组相乘。特别的是,既不是点乘,也不是矩阵乘法,而是将各自每一行分别相乘再拼接得到一个三维数组,具体代码大致如下

import torch
a = torch.Tensor(range(6)).reshape(2, 3)
b = torch.Tensor(range(1, 7)).reshape(2, 3)
batch = len(a)
length = len(a[0])
c = torch.zeros(batch, batch, length)
start = time.time()
for i in range(batch):
    for j in range(batch):
        c[i][j] = a[i] * b[j]

可以看到,实际上是使用a的每一行乘以b的每一行作为c的一个元素,最终c是一个三维数组。

但问题在于,python解释器的执行速度较慢,因此,我做了改进,将a和b分别按照维度1和维度进行扩展,再相乘,结果是一样的。

a_batch = torch.stack([a] * batch, dim=1)
b_batch = torch.stack([b] * batch, dim=0)
c_batch = a_batch * b_batch

速度显然快了很多,但是比较占内存且麻烦。回头看了一个学弟的代码的解决方案,使用了None扩展便捷地解决了这个问题。

a_2 = a[:, None, :]
b_2 = b[None, :, :]
c_2 = a_2 * b_2

实际上,a_2和b_2的维度大小是在None那一维度为1而不是我stack那样的数个。输出各自的维度如下

a.shape  torch.Size([2, 3])
a_batch.shape  torch.Size([2, 2, 3])
a_2.shape  torch.Size([2, 1, 3])
b.shape  torch.Size([2, 3])
b_batch.shape  torch.Size([2, 2, 3])
b_2.shape  torch.Size([1, 2, 3])

疑惑是,之前看的csdn博客,都说过pytorch的矩阵点乘需要两个矩阵的维度相同,然而a_2和b_2为何维度不同也能相乘呢?因此去查询pytorch的官方文档。即查看torch.mul的文档https://pytorch.org/docs/stable/torch.html?highlight=mul#torch.mul,可见

可见,其实在维度不相同时,如果矩阵是可广播的也可以相乘,查看broadcastable的定义 https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

中文翻译即为

如果以下规则成立,则两个张量是“可广播的”:

  1. 每个张量至少有一个维度
  2. 当迭代尺寸时,从尾部尺寸开始,尺寸必须相等,或者其中一个尺寸为1,或者尺寸不存在

官方举例

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

所以,还是要多查看官方文档,博客上很多是不全面的。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值