乘法
三维矩阵乘法
最近笔者在做NLP的task,其中需要用到一个匹配两个句子之间相似度程度的技术,arxiv参考论文查阅
在这篇论文中用到以下算式实现计算两个sequence之间的相似程度,我们知道在训练过程中一般batch的形状都是三维的[B,L,E]
- B表示batch_size
- L表示sequence的长度
- E表示embedding dim
这里的F是一个feed forward network with ReLU as activation function.
那么我们如何进行三维矩阵的乘法呢?假设a,b是经过网络之后的输出,则代码示例如下:
import torch
a = torch.randn(3,4,5)
b = torch.randn(3,5,4)
c = torch.bmm(a,b)
print(c.shape)
# C: torch.Size([3, 4, 4])
Element-wise乘法
在一些论文中,还有通过Element-wise的方法体现一些举证的相关性,在此也贴出来:
拼接
在看论文的时候我们经常可以看到一些vector直接用了,
或者;
连接,这其实就是表示对vetor进行了相应的拼接操作:
,
按行拼接;
按列拼接
举一个三维矩阵的例子,那么按行拼接都是按照dim=1拼接,按列拼接就是按照dim=2拼接。pytorch的相应代码如下:
import torch
a = torch.randn(16,3,4)
b = torch.randn(16,6,4)
c = torch.randn(16,3,10)
# column-wise
out1 = torch.cat((a,c),dim=2)
# row-wise
out2 = torch.cat((a,b),dim=1)
拆分
拆分张量:torch.split()、torch.chunk()
torch.split(tensor, split_size, dim=0)
将输入张量分割成相等形状的 chunks(如果可分)。 如果沿指定维的张量形状大小不能被 split_size 整分, 则最后一个分块会小于其它分块。
举个例子:
>>> x = torch.randn(3, 10, 6)
>>> a, b, c = x.split(1, 0) # 在 0 维进行间隔维 1 的拆分
>>> a.size(), b.size(), c.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> d, e = x.split(2, 0) # 在 0 维进行间隔维 2 的拆分
>>> d.size(), e.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))
把张量在 0 维度上以间隔 2 来拆分时,只能分成 2 份,且只能把前面部分先以间隔 2 来拆分,后面不足 2 的部分就直接作为一个分块。
torch.chunk(tensor, chunks, dim=0)
在给定维度(轴)上将输入张量进行分块儿
直接用上面的数据来举个例子:
>>> l, m, n = x.chunk(3, 0) # 在 0 维上拆分成 3 份
>>> l.size(), m.size(), n.size()
(torch.Size([1, 10, 6]), torch.Size([1, 10, 6]), torch.Size([1, 10, 6]))
>>> u, v = x.chunk(2, 0) # 在 0 维上拆分成 2 份
>>> u.size(), v.size()
(torch.Size([2, 10, 6]), torch.Size([1, 10, 6]))
把张量在 0 维度上拆分成 2 部分时,无法平均分配,以上面的结果来看,可以看成是,用 0 维度的尺寸除以需要拆分的份数,把余数作为最后一个分块的间隔大小,再把前面的分块以相同的间隔拆分。
在某一维度上拆分的份数不能比这一维度的尺寸大
https://zhuanlan.zhihu.com/p/100069938