1、cat函数
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
print(torch.cat([a,b],dim=0).shape) # torch.Size([9, 32, 8])
cat函数的第一个参数为需要合并的两个张量,第二个参数dim表示哪一维需要合并,如上式表示第0维需要合并,注:除了需合并的维数上的数目可以不同,其他维数需相同
2、stack函数
a = torch.rand(32,8)
b = torch.rand(32,8)
print(torch.stack([a,b],dim=0).shape) #torch.Size([2, 32, 8])
stack函数第一个参数表示需要合并的两个张量,第二个参数表示需要合并的维度,如上式表示,将两个32*8的张量合并成为一个2*32*8的张量,而使用cat函数合并之后的结果为64*8
3、split函数
a = torch.rand(4,3,5)
a1,a2 = a.split(2,dim=0)
print(a1.shape,a2.shape) #torch.Size([2, 3, 5]) torch.Size([2, 3, 5])
split函数表示对张量进行拆分,第一个参数表示拆分时的步长,第二个参数表示需要对第0维进行拆分,注:第一个步长可以为张量[3,2,1],表示按照步长3,2,1拆分张量
4、chunk函数
a = torch.rand(4,3,5)
a1,a2 = a.chunk(2,dim=0)
print(a1.shape,a2.shape) #torch.Size([2, 3, 5]) torch.Size([2, 3, 5])
chunk函数的第一个参数表示将目标维度分为几块,如上式2就表示分为两块,第二个参数为需拆分的维度