1.cat()----cat([a,b],dim=) a,b为要合并的存在list中,dim表示要合并的维度
a=torch.rand(4,3,32,32)
b=torch.rand(5,3,32,32)
print(torch.cat([a,b],dim=0).shape)
c=torch.rand(4,1,32,32)
#报错 dim=0位置相同,dim=1位置有差异,修改时要保证其他位置相同
#print(torch.cat([a,c],dim=0).shape)
print(torch.cat([a,c],dim=1).shape)
#都相同时可以改任意
a=torch.rand(4,3,16,32)
b=torch.rand(4,3,16,32)
print(torch.cat([a,b],dim=2).shape)
2.stack()----stack([a,b],dim=) a,b为要合并的存在list中,dim表示要新建维度的位置
#a,b维度必须一致
a=torch.rand(32,8)
b=torch.rand(32,8)
print(torch.stack([a,b],dim=0).shape)
#torch.Size([2,32,8]) 可见多出一个维度,为0时是a的数据,为1时是b的数据
a=torch.rand(4,3,16,32)
b=torch.rand(4,3,16,32)
print(torch.stack([a,b],dim=2).shape)
#torch.Size([4, 3, 2, 16, 32])
3.split()------split()按长度进行拆分
#split
a=torch.rand(32,8)
b=torch.rand(32,8)
#torch.Size([2, 32, 8])
c=torch.stack([a,b],dim=0)
#按照长度1在维度0开始拆分
d,e=c.split(1,dim=0)
#torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
print(d.shape,e.shape)
c=torch.rand(5,32,8)
##按照长度1,2,2在维度0开始拆分
d,e,f=c.split([1,2,2],dim=0)
#torch.Size([1, 32, 8]) torch.Size([2, 32, 8]) torch.Size([2, 32, 8])
print(d.shape,e.shape,f.shape)
4.chunk()-----按数量分割
a=torch.rand(32,8)
b=torch.rand(32,8)
#torch.Size([2, 32, 8])
c=torch.stack([a,b],dim=0)
d,e=c.chunk(2,dim=0)
#torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
print(d.shape,e.shape)
c=torch.rand(5,32,8)
d,e,f=c.chunk(4,dim=0)
#torch.Size([2, 32, 8]) torch.Size([2, 32, 8]) torch.Size([1, 32, 8])
print(d.shape,e.shape,f.shape)