# 1. torch.cat(tensors, dim=0): concatenate tensors along an EXISTING dimension
# torch.cat joins tensors along an EXISTING dimension. The result keeps the
# inputs' rank; only the size of the chosen dim grows (2 + 2 = 4, 3 + 3 = 6).
# Every other dimension must already match between the inputs.
data11 = torch.randint(0, 10, (2, 3, 2))   # shape (2, 3, 2), values in [0, 10)
data12 = torch.randint(10, 20, (2, 3, 2))  # shape (2, 3, 2), values in [10, 20)

# One concatenation per axis of the 3-D inputs:
#   data13 -> along dim 0 -> torch.Size([4, 3, 2])
#   data14 -> along dim 1 -> torch.Size([2, 6, 2])
#   data15 -> along dim 2 -> torch.Size([2, 3, 4])
data13, data14, data15 = (torch.cat((data11, data12), dim=axis) for axis in range(3))
# 2. torch.stack(tensors, dim=0): stack same-shape tensors along a NEW dimension inserted at `dim`
# torch.stack joins tensors along a NEW dimension inserted at position `dim`.
# All inputs must have exactly the same shape, and the result has one more
# dimension than the inputs. Here two (3, 3) tensors become a 3-D tensor whose
# new axis has size 2. Contrast torch.cat, which grows an EXISTING dimension.
torch.manual_seed(1)  # seed the global RNG so the sample values are repeatable

data21 = torch.randint(0, 10, [3, 3])  # torch.Size([3, 3]), values in [0, 10)
# Example output recorded after seeding with 1
# (exact values depend on PyTorch's RNG implementation):
# tensor([[5, 9, 4],
#         [8, 3, 3],
#         [1, 1, 9]])

data22 = torch.randint(10, 20, [3, 3])  # torch.Size([3, 3]), values in [10, 20)
# tensor([[12, 18, 19],
#         [16, 13, 13],
#         [10, 12, 11]])

# dim=0: new LEADING axis -> torch.Size([2, 3, 3]).
# data23[i] is the i-th input tensor, so the inputs sit one after the other.
data23 = torch.stack([data21, data22], 0)
# tensor([[[ 5,  9,  4],
#          [ 8,  3,  3],
#          [ 1,  1,  9]],
#
#         [[12, 18, 19],
#          [16, 13, 13],
#          [10, 12, 11]]])

# dim=1: new MIDDLE axis -> torch.Size([3, 2, 3]).
# data24[:, i] is the i-th input, so row r of data21 is paired with row r of data22.
data24 = torch.stack([data21, data22], 1)
# tensor([[[ 5,  9,  4],
#          [12, 18, 19]],
#
#         [[ 8,  3,  3],
#          [16, 13, 13]],
#
#         [[ 1,  1,  9],
#          [10, 12, 11]]])

# dim=2: new TRAILING axis -> torch.Size([3, 3, 2]).
# data25[..., i] is the i-th input, so element (r, c) of each input becomes a pair.
data25 = torch.stack([data21, data22], 2)
# tensor([[[ 5, 12],
#          [ 9, 18],
#          [ 4, 19]],
#
#         [[ 8, 16],
#          [ 3, 13],
#          [ 3, 13]],
#
#         [[ 1, 10],
#          [ 1, 12],
#          [ 9, 11]]])