torch.cat()一种常见的拼接函数
不使用torch.cat()
'inputs为2个形状为[2 , 3]的矩阵 '
inputs = [x1, x2]
print(inputs)
运行一下后结果为:
[tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32),
tensor([[12, 22, 32],
[22, 32, 42]], dtype=torch.int32)]
使用torch.cat()后
In [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4, 3])
In [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])
In [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)