cat() 中第一个参数是一个需要合并的两个tensor的list,第二个参数表示在哪个维度上合并。除了要合并的维度,其他维度要相同
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
print(torch.cat([a, b], dim=0).shape) # torch.Size([9, 32, 8])
a1 = torch.rand(4, 3, 32, 32)
a2 = torch.rand(4, 1, 32, 32)
print(torch.cat([a1, a2], dim=0).shape) # RuntimeError: invalid argument 0 其他维度要一致
print(torch.cat([a1, a2], dim=1).shape) # torch.Size([4, 4, 32, 32])
stack() 第一个参数是一个需要合并的两个tensor的list,第二个为指定的维度,表示在指定的维度前创建一个新的维度,需要两个合并的tensor的维度完全一样
a1 = torch.rand(4, 3, 16, 32)
a2 = torch.rand(4, 3, 16, 32)
print(torch.cat([a1, a2], dim=2).shape) # torch.Size([4, 3, 32, 32])
print(torch.stack([a1, a2], dim=2).shape) # torch.Size([4, 3, 2, 16, 32])
a = torch.rand(32, 8)
b = torch.rand(32, 8)
print(torch.stack([a, b], dim=0).shape) # torch.Size([2, 32, 8])