在pytorch中,常见的拼接函数主要是两个:
stack()
cat()
stack() 函数与 cat() 函数类似
1 stack() 介绍
沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状
把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠
2 参数介绍
outputs = torch.stack(inputs, dim = 0) -> Tensor
inputs : 待连接的张量序列。 注:python的序列数据只有list和tuple
新的维度, 必须在0到len(outputs)之间。 注:len(outputs)是生成数据的维度大小,也就是outputs的维度值