场景
给定一个三维矩阵x(batch, seq_len, input_size)
,最后需要得到一个拼接矩阵e(batch_size, seq_len, seq_len, 2*input_size)
,例如e[0, 1, 2, :]=x[0, 1, :] || x[0, 2, :]
,其中||表示拼接。
实现
x = torch.randn((64, 24, 7))
seq_len, size = x.shape[1], x.shape[2]
x1 = x.repeat(1, 1, seq_len).view(x.shape[0], seq_len * seq_len, -1)
x2 = x.repeat(1, seq_len, 1)
cat_x = torch.cat([x1, x2], dim=-1).view(x.shape[0], seq_len, -1, 2 * size)
其中x1 = x.repeat(1, 1, seq_len).view(x.shape[0], seq_len * seq_len, -1)
,x.repeat(1, 1, seq_len)
表示将第2维度重复了seq_len
次,从(64, 24, 7)
变成(64, 24, 24*7)
,接着view(x.shape[0], seq_len * seq_len, -1)
变成(64, 24 * 24, 7)
,相当于11...1122...22...
。同理x2(64, 24*24, 7)
,x2
的重复形式是1,2,…,24,1,…,24…。接着,我们将x1
和x2
进行拼接,变成(64, 24*24, 14)
,这里相当于是24个1首先依次和1…24拼接,然后是24个2个1…24拼接,目的达成。验证:
print(cat_x.shape)
print(x[0, 1, :])
print(x[0, 2, :])
print(cat_x[0, 1, 2, :])
torch.Size([64, 24, 24, 14])
tensor([-1.2668, -1.1904, 0.8832, 2.0187, 0.3969, 0.1294, 0.5685])
tensor([-0.8796, 0.0320, 1.2344, 1.0180, 0.2738, 0.1357, -0.4144])
tensor([-1.2668, -1.1904, 0.8832, 2.0187, 0.3969, 0.1294, 0.5685, -0.8796,
0.0320, 1.2344, 1.0180, 0.2738, 0.1357, -0.4144])
可以发现cat_x[0, 1, 2, :]
确实是x[0, 1, :]
和x[0, 2, :]
的拼接形式。