深度学习初探/02-Pytorch知识/06-拼接与拆分
一、合并
1、cat
场景:两个老师分别统计班级成绩,现需将两位老师统计的数据进行合并。
A老师统计class 1-4 ⇒ \Rightarrow ⇒ [class1-4, students, scores];
B老师统计class 5-9 ⇒ \Rightarrow ⇒ [class 5-9, students, scores]。
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
# 将a和b在dimension0合并
cb = torch.cat([a, b], dim=0)
print(cb.shape)
Out:
torch.Size([9, 32, 8])
2、stack
使用stack时,2个tensor的维度必须完全一样
会在合并的dim前插入一个新的dim,用于选择(比如:新dim值=0时,选择原a的数据;新dim值=1时,选择原b的数据)
a = torch.rand(4, 32, 8)
b = torch.rand(4, 32, 8)
# 将a和b在dimension0合并
cb = torch.stack([a, b], dim=0)
print(cb.shape)
Out: torch.Size([2, 4, 32, 8])
二、拆分
1、split 根据长度进行拆分
c = torch.rand([2, 32, 8])
#1 直接给出拆分的目标长度配额
aa, bb = c.split([1, 1], dim=0)
print(aa.shape, bb.shape)
Out: torch.Size([1, 32, 8]) torch.Size([2, 32, 8])
#2 按某种长度进行平均拆分
aa, bb = c.split(2, dim=0)
print(aa.shape, bb.shape)
Out:torch.Size([2, 32, 8]) torch.Size([2, 32, 8])
2、chunk 根据数量进行拆分(即分成多少份)
c = torch.rand([4, 32, 8])
aa, bb = c.chunk(2, dim=0)
print(aa.shape, bb.shape)
Out: torch.Size([2, 32, 8]) torch.Size([2, 32, 8])