cat将按照dim维拼接,即cat后的维度dim维是所有矩阵dim维的和,其余的维度不变,见例子:
x1 = torch.arange(1, 7).view([2, 3])
x2 = torch.arange(101, 107).view([2, 3])
x3 = torch.arange(201, 207).view([2, 3])
x4 = torch.cat((x1,x2,x3), dim=0)
print(x4.shape)
print(x4)
x4 = torch.cat((x1,x2,x3), dim=1)
print(x4.shape)
print(x4)
结果:
torch.Size([6, 3])
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[101, 102, 103],
[104, 105, 106],
[201, 202, 203],
[204, 205, 206]])
torch.Size([2, 9])
tensor([[ 1, 2, 3, 101, 102, 103, 201, 202, 203],
[ 4, 5, 6, 104, 105, 106, 204, 205, 206]])