torch.cat
-
定义:
torch.cat(tensors, dim=0, out=None)
→ Tensor -
参数:
tensors
(Sequence of Tensors):要连接的张量序列。dim
(int, 可选):沿着此维连接张量序列。当dim=0
时,torch.cat()
会按行连接多个张量,也就是在第一个维度上进行连接。这将导致张量在垂直方向上叠加。当dim=1
时,torch.cat()
会按列连接多个张量,也就是在第二个维度上进行连接。这将导致张量在水平方向上叠加。out
(Tensor, 可选):输出张量。
-
返回值:
- 一个新的张量,它是输入张量在指定维度上的连接。
-
用途:
torch.cat
用于将给定维度上的一系列张量连接在一起。张量在除连接维以外的所有维度上必须具有相同的形状。
torch.stack
-
定义:
torch.stack(tensors, dim=0, out=None)
→ Tensor -
参数:
tensors
(Sequence of Tensors):要堆叠的张量序列,所有张量都应有相同的形状。dim
(int, 可选):插入新维度的索引。out
(Tensor, 可选):输出张量。
-
返回值:
- 一个新的张量,它沿着新维度对输入张量序列进行堆叠。
-
用途:
torch.stack
用于创建一个新的维度,并在该维度上堆叠一系列张量。与torch.cat
不同,torch.stack
会增加一个新的维度,所以输出张量的维度会比输入张量多一个。