torch.cat 将两个tensor拼接在一起,即拼接、联系在一起的意思;
使用torch.cat((A, B),dim)时,除拼接dim维对应的值不同外,其余维的数值需一致
A = torch.ones(2,3)
B = 2 * torch.ones(4,3)
C = torch.cat((A, B),dim=0)
print(C.shape)
D = 2 * torch.ones(2,4)
C = torch.cat((A, D),dim=1)
print(C.shape)
torch.cat 将两个tensor拼接在一起,即拼接、联系在一起的意思;
使用torch.cat((A, B),dim)时,除拼接dim维对应的值不同外,其余维的数值需一致
A = torch.ones(2,3)
B = 2 * torch.ones(4,3)
C = torch.cat((A, B),dim=0)
print(C.shape)
D = 2 * torch.ones(2,4)
C = torch.cat((A, D),dim=1)
print(C.shape)