一、合并
1. cat 函数
- 规则:
- 所合并的数据的dim一致
- 要合并的维度上shape可以不一致,其余的shape必须一致
- 理解:[class,student]合并=>[class,student],合并后的班级在含以上相同
- 例子
a = torch.rand(4,3,16,32)
b = torch.rand(4,3,16,32)
print(torch.cat([a,b],dim=2).shape)
2.stack函数
- 规则:
- 要合并的两个维度必须一致
- 会在合并的维度前插入一个新的维度
- 理解:[class,student]合并=>[class_id,class,student],相当于合并后两个班级分开,意义上不同,比如dim=0维度上的班级是尖子班,dim=1维度上的班级是普通班。
- 例子
print(' stack ')
print(torch.stack([a,b],dim=2).shape)
a = torch.rand(3,5)
b = torch.rand(3,5)
print(torch.stack([a,b],dim=0).shape)
二、分割
1. split函数
- 规则:
- 给定在某一维度拆分后长度
- 给定在某一维度拆分后的每个长度
- 例子
a = torch.rand(32,8)
b = torch.rand(32,8)
c = torch.stack([a,b],dim=0)
aa,bb = c.split([1,1],dim=0)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
aa,bb = c.split(1,dim=0)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
c = torch.rand(9,4)
aa,bb,cc = c.split(3,dim=0)
print(aa.shape,bb.shape,cc.shape)
aa,bb,cc = c.split([2,3,4],dim=0)
print(aa.shape,bb.shape,cc.shape)
2. chuck函数
- 规则
- 参数:要分割的数量(不足的向上取整),和维度
- 例子
a = torch.rand(4,5,2)
aa,bb = a.chunk(2,dim=0)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
aa,bb = a.chunk(2,dim=1)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)