今天碰到一个问题,一个矩阵,一行代表一个球的特征,现在要构建两两之间的关系,也就是形成一个新的矩阵,每行矩阵把两个球的特征拼接在一起。
抽象化一下,以一共(N=2)个球为例
# 已知:
[[x],
[y]]
# 得到:
[[x, x],
[x, y],
[y, x],
[y, y]]
举个例子:
import torch
a = torch.arange(6).reshape((2,3))
a
#-----out
#tensor([[0, 1, 2],
# [3, 4, 5]])
这个矩阵就代表有两(N=2)个球,(假设第一行为x球,第二行为y球)。看到的代码是:
torch.cat((a.repeat(1, N).view(N * N, -1), a.repeat(N, 1)), dim=1)
torch.repeat()
分步骤来看:torch.repeat()
函数中的数字,代表该轴重复几次。对比以下两个步骤:
b1= a.repeat(1,N)
b1
#-----out
#tensor([[0, 1, 2, 0, 1, 2],
# [3, 4, 5, 3, 4, 5]])
c= a.repeat(N,1)
c
#-----out
#tensor([[0, 1, 2],
# [3, 4, 5],
# [0, 1, 2],
# [3, 4, 5]])
torch.view()
再来看torch.view()
,对以上b1
进行一些操作:
b = b1.view(N*N,-1) # 因为是两两相互关系,也算上自身,所以会有N*N条数据
b
#-----out
# tensor([[0, 1, 2],
# [0, 1, 2],
# [3, 4, 5],
# [3, 4, 5]])
为了更清楚 torch.view()
是怎么工作的,在这个例子上,我们做一点改变:
b2 = b1.view(3,-1)
b2
#-----out
# tensor([[0, 1, 2, 0],
# [1, 2, 3, 4],
# [5, 3, 4, 5]])
从这儿可以看出,torch.view()
在得到具体要的大小后,按照从上到下,从做到右的方法来取这一行需要有几个数字。
那么最后,对b,c
进行拼接
torch.cat((b,c),dim=1) # dim 指定拼接的轴
#-----out
# tensor([[0, 1, 2, 0, 1, 2],
# [0, 1, 2, 3, 4, 5],
# [3, 4, 5, 0, 1, 2],
# [3, 4, 5, 3, 4, 5]])
最后实现了 球之间两两拼接。
[[x, x],
[x, y],
[y, x],
[y, y]]