1. 定义
官方手册中描述为:
torch.cat(inputs, dimension=0) → Tensor
在给定维度上对输入的张量序列seq 进行连接操作。
torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数可以通过下面例子更好的理解。
参数:
- inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
- dimension (int, optional) – 沿着此维连接张量序列。
2. 例子
>>> x = torch.randn(2, 3)
>>> x
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]
>>> torch.cat((x, x, x), 0)
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]
>>> torch.cat((x, x, x), 1)
0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]
torch.cat((x, x, x), 1)中的 0 or 1 就是指示的维度。
除此之外,可以指示为-1。
我将举几个例子
如图,a是2x3 b是2x5的一个张量
拼接后:
一句话总结:上下拼接要列数相同,左右拼接要行数相同。
另,用torch.cat拼接list里的tensor:
先整个list:
可以清楚的看到已经拼接好了,即参数可以直接传入一个seq