torch.cat()
-
根据给定的维度,将tensor序列连接起来
- tensor序列中的tensor,除了连接维度,其他维度的形状必须相同
- 指定的维度必须在tensor的维度范围内。连接后得到维度不变,形状不同的tensor
- tensor序列可以表示为列表类型 / 元组类型
- tensor的维度不等于tensor的形状。维度为3的tensor的形状可以是[2, 3, 4],也可以是[5, 6, 7]。
-
示例:
从示例中可以看出,t1, t2, t3 的维度为3,连接后得到的t4, t5的维度也为3,形状为[2, 7, 5]。
参考链接: torch.cat()
torch.stack()
-
作用:在新的维度上连接tensor序列
- tensor序列中的tensor形状要完全相同,即不但要维度相同,形状也要一样
- 得到的tensor在指定维度增加一个维度,增加的维度的形状为tensor序列的长度
-
示例
t4在第1维增加一个维度,原来的第1维变成第2维,t5在最后增加一个维度。
torch.cat()是在tensor的原维度进行操作,而torch.stack()操作会改变tensor的维度,增加一个维度。
参考链接: torch.stack()