官方文档:torch.cat — PyTorch 2.0 documentation
torch.cat(tensors, dim=0, *, out=None) → Tensor
将给定维度中的seq张量的给定序列连接起来。所有张量必须具有相同的形状(连接维度除外)或为空。
torch.cat()可以看作是torch.split()和torch.chunk()的逆操作。
参数:
- tensors (sequence of Tensors) –任何相同类型的张量的python序列。提供的非空张量必须具有相同的形状,cat维度除外。
- dim (int, optional)–张量连接的维度,选择的扩维, 必须在
0
到len(inputs[0])
之间
import torch
x = torch.randn(2, 3)
print(x)
x.size()
torch.cat((x, x, x), 0) # 列不变,行增加
torch.cat((x, x, x), 0).size()
torch.cat((x, x, x), 1) # 行不变,列增加
torch.cat((x, x, x), 1).size()