关于torch.bmm()函数计算过程

很多框架中提供的矩阵乘法都是出于简化计算的考虑,很多情况下在进行计算时候都会牵扯到 batch size 这一个维度,这就使得很多矩阵的计算是三维的,Pytorch中的bmm()函数就可以很方便的实现三维数组的乘法,而不用拆成二维数组使用for循环解决。在查资料的时候发现有些博客写的有些小地方不太对,而且有很多提问都是关于 bmm()函数具体是如何计算的,因此记录。

1.torch.bmm()

函数定义:

def bmm(self: Tensor,
        mat2: Tensor,
        *,
        out: Optional[Tensor] = None) -> Tensor

函数的传入参数很简单,两个三维矩阵而已,只是要注意这两个矩阵的shape有一些要求:

res = torch.bmm(ma, mb)
ma: [a, b, c]
mb: [a, c, d]

也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,其实这里的意思已经很明白了,两个三维矩阵的乘法其实就是保持第一维度不变,每次相当于一个切片做二维矩阵的乘法,对于上面的矩阵来说,就是 for i in range(a) 然后 ma[i] * mb[i],这是一个熟悉的二维矩阵乘法,两个矩阵的shape分别是[b, c][c, d]。因此,输出的结果的shape也很明显了:[a, b, d]。下面验证一下:

2.验证

首先创建两个tensor:

a = torch.linspace(1, 24, 24).view(2, 3, 4)  # shape [2, 3, 4]

b = torch.linspace(1, 16, 16).view(2, 4, 2)   # shape [2, 4, 2]

两个tensor分别是:

tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],

        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]])
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.],
         [13., 14.],
         [15., 16.]]])

接下来分别使用bmm函数和for循环方式实现乘法:

c = torch.bmm(a, b)
print(c)
d = np.array([torch.mm(a[i], b[i]).numpy() for i in range(len(a))])
print(d)

输出分别是:

tensor([[[  50.,   60.],
         [ 114.,  140.],
         [ 178.,  220.]],

        [[ 706.,  764.],
         [ 898.,  972.],
         [1090., 1180.]]])
[[[  50.   60.]
  [ 114.  140.]
  [ 178.  220.]]

 [[ 706.  764.]
  [ 898.  972.]
  [1090. 1180.]]]

也可以使用函数检查一下:

print((d == c.numpy()).all())
输出:True

3.更实际一点的想法

就像刚才所说的那样,只要根据实际的情况考虑一下,这个函数的计算过程很好理解,由于 batch size的引入,所以处理数据的时候很容易出现三维数组,例如处理文本计算attention权重的时候,很容易得到的权重矩阵shape是 [batch_size, sequence_length],然后需要相乘的隐状态矩阵是 [batch_size, sequence_length, hidden_size]。按照attention的计算方式,实际上就是权重矩阵中每一行的数值分别乘以隐状态矩阵中每一行的对应位置的隐状态,这个过程当然可以写循环,也可以简单的使用bmm函数计算,先将权重矩阵reshape成 [batch_size, 1, sequence_length]然后bmm(weigths_matrix, hidden_matrix)然后得到的结果就是attention计算的结果了。

  • 54
    点赞
  • 137
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值