拼接
第一种方法 cat
在指定维度上拼接
t = torch.ones((3,2))
res = torch.cat([t, t], dim=0)
res
按行拼接,会生成一个(6,2)的张量
如果dim=1就是按列拼接,生成(3,4)的张量
如果是三维张量:
t = torch.ones((2,3,2))
res = torch.cat([t, t], dim=0)
res
这个时候dim=0是针对第三维的个数, dim=1,2才分别是行,列
第二种方法stack
在新创建的维度上拼接
t = torch.ones((3,2))
res = torch.stack([t, t], dim=0)
res
这里的dim=0会导致结果变成(x,3,2)的形式,这个x就等于stack中有几个t,这里就是(2,3,2)
同理,dim=1会生成(3,2,2)
切分
第一种split
在某个维度上进行切分
t = torch.ones((3,5))
list_of_tensor = torch.split(t, [2,1,2], dim=1)
for i in list_of_tensor:
print(i)
比如这里就注明了在dim=1的维度上切分成三块,分别是2,1,2的大小,注意这里切分出来的总和要正好等于这个维度,即2+1+2=5,否则会报错
当然也可以直接注明每一份的大小
这里就规定每一份都是2,仍然在,那么就会出现2,2,1的结果
list_of_tensor = torch.split(t, 2, dim=1)
for i in list_of_tensor:
print(i)
第二种chunk
这种方法规定chunks, 即要切分的份数
假设总数7份,要切出3份,是不能整除的,就会出现3,3,1的结果,同理总数7份,切成4份,也不能整除,就会出现2,2,2,1的结果
t = torch.ones((3,7))
list_of_tensor = torch.chunk(t, chunks = 4, dim=1)
for i in list_of_tensor:
print(i)