pytorch tensor拼接 cat与stack
1 概述
pytorch中cat和stack都是 用于tenssor拼接的方式,但其也存在差异。
2 torch.cat()
功能说明:
将多个tensor按照指定维度进行拼接
参数说明:
tensors:将要拼接tensor按顺序写在一个元组中;
dim:在哪个维度进行拼接。总之,记得在哪个维度进行拼接,最后拼接得到的tensor的shape那个维度就会变多。见下面的例子就很容易理解。
注意:cat之后得到的结果中,dim的数量不变,之不够某个dim相加了。
举例说明:
# 下面生成一个tensor,这个tensor的shape=(2,3,4)。该tensor dim0, dim1, dim2 = 2,3,4 >>> x = torch.rand(2,3,4) >>> x tensor([[[0.7803, 0.3229, 0.1049, 0.5638], [0.4423, 0.8906, 0.7401, 0.7551], [0.7245, 0.3624, 0.2235, 0.6407]], [[0.0051, 0.4675, 0.9110, 0.9351], [0.4093, 0.8809, 0.9296, 0.6471], [0.8142, 0.6557, 0.6089, 0.1754]]]) >>> torch.cat((x,x),0) # 产生shape=(4,3,4),这是因为cat时dim参数=0,所以tensor在dim0参数相加了 tensor([[[0.7803, 0.3229, 0.1049, 0.5638], [0.4423, 0.8906,