1. cat 合并：沿已有维度拼接（不新增维度；除拼接维外，其余维度大小须一致）
import torch

# Two tensors that agree on every dimension except dim 0 (5 vs 4);
# torch.cat only requires the non-concatenated dimensions to match.
a = torch.randn(5, 2, 3, 8)
b = torch.randn(4, 2, 3, 8)

# Concatenate along the existing dim 0: no new axis, lengths add up (5 + 4).
merged = torch.cat((a, b), dim=0)
print(merged.shape)
结果:torch.Size([9, 2, 3, 8])（dim 0 上 5 + 4 = 9，其余维度不变）
2. stack：新增一个维度后堆叠（所有输入张量的形状必须完全相同）
import torch

# stack requires every input to have exactly the same shape.
a = torch.randn(4, 2, 3, 8)
b = torch.randn(4, 2, 3, 8)

# A brand-new axis is inserted at position 0 and a, b are laid out along it,
# so the result has one more dimension than each input.
stacked = torch.stack((a, b), dim=0)
print(stacked.shape)
结果:torch.Size([2, 4, 2, 3, 8])（最前面新增了长度为 2 的维度，对应 a、b 两个张量）
3. split：按长度切分（传 int 表示每块的长度；传列表表示各块的长度，列表之和须等于该维大小）
import torch

a = torch.randn(4, 2, 3, 8)

# Int split size: pieces of length 2 along dim 0, so 4 rows become 2 + 2.
a1, a2 = torch.split(a, 2, dim=0)
for piece in (a1, a2):
    print(piece.shape)

# List of sizes: the explicit length of each piece (3 + 1 must equal 4).
a1, a2 = torch.split(a, [3, 1], dim=0)
for piece in (a1, a2):
    print(piece.shape)
结果:torch.Size([2, 2, 3, 8])
torch.Size([2, 2, 3, 8])
torch.Size([3, 2, 3, 8])
torch.Size([1, 2, 3, 8])
4. chunk：按块数切分（参数是块数而不是长度；不能整除时最后一块较小，返回的块数也可能少于指定值）
import torch

a = torch.randn(4, 2, 3, 8)

# chunk takes the NUMBER of pieces, not their length:
# 4 chunks over 4 rows along dim 0 gives 1 row per chunk.
a1, a2, a3, a4 = torch.chunk(a, 4, dim=0)
for piece in (a1, a2, a3, a4):
    print(piece.shape)
结果:torch.Size([1, 2, 3, 8])
torch.Size([1, 2, 3, 8])
torch.Size([1, 2, 3, 8])
torch.Size([1, 2, 3, 8])