Reference link: torch.cat(tensors, dim=0, out=None)
torch.cat(tensors, dim=0, out=None) → Tensor
Concatenates the given sequence of seq tensors in the given dimension.
All tensors must either have the same shape
(except in the concatenating dimension) or be empty.
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
torch.cat() can be best understood via examples.
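The inverse relationship with torch.split() and torch.chunk() can be checked directly. The sketch below uses an arbitrary 6x4 tensor x chosen only for illustration. It splits x along one dimension, concatenates the pieces back along the same dimension, and compares the result with the original.

```python
import torch

# A 6x4 tensor holding 0..23, so the pieces are easy to inspect.
x = torch.arange(24).reshape(6, 4)

# torch.split: pieces of size 2 along dim 0 -> three (2, 4) tensors.
parts = torch.split(x, 2, dim=0)

# torch.chunk: 2 chunks along dim 1 -> two (6, 2) tensors.
halves = torch.chunk(x, 2, dim=1)

# Concatenating along the dimension that was split restores the original.
restored_rows = torch.cat(parts, dim=0)
restored_cols = torch.cat(halves, dim=1)

print(torch.equal(restored_rows, x))  # True
print(torch.equal(restored_cols, x))  # True
```

The round trip only works when torch.cat uses the same dim as the split. Concatenating the row pieces along dim=1 would produce a (2, 12) tensor instead.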
Parameters
tensors (sequence of Tensors) – any python sequence of
tensors of the same type. Non-empty tensors provided
must have the same shape, except in the cat dimension.
dim (int, optional) – the dimension over which the tensors are concatenated
out (Tensor, optional) – the output tensor
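One way to picture what dim means is to re-implement the operation by hand. manual_cat below is a hypothetical helper for illustration, not part of PyTorch. It allocates an output whose size along dim is the sum of the inputs' sizes, then copies each input into its own slice using Tensor.narrow() and copy_(). It deliberately skips the special handling the real torch.cat has, such as the empty-tensor case mentioned above and any dtype or device checks.

```python
import torch

def manual_cat(tensors, dim=0):
    """Hypothetical, simplified re-implementation of torch.cat (illustration only)."""
    first = tensors[0]
    ndim = first.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-D tensors")
    dim = dim % ndim  # map negative dims such as -1 to a positive index
    # Every input must match the first one in all dims except `dim`.
    for t in tensors:
        if t.dim() != ndim:
            raise ValueError("all tensors must have the same number of dims")
        for d in range(ndim):
            if d != dim and t.size(d) != first.size(d):
                raise ValueError(f"size mismatch in dim {d}")
    # Output size along `dim` is the sum of the input sizes along `dim`.
    out_shape = list(first.shape)
    out_shape[dim] = sum(t.size(dim) for t in tensors)
    out = torch.empty(out_shape, dtype=first.dtype, device=first.device)
    offset = 0
    for t in tensors:
        n = t.size(dim)
        # narrow() returns a view into `out`; copy_() writes `t` into that view.
        out.narrow(dim, offset, n).copy_(t)
        offset += n
    return out

a = torch.randn(2, 3)
b = torch.randn(4, 3)
print(torch.equal(manual_cat([a, b], dim=0), torch.cat([a, b], dim=0)))  # True
```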
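Concatenates multiple tensors along the specified dimension dim.
tensors can be any Python sequence of tensors of the same type;
the non-empty tensors in it may differ in size along the concatenation dimension,
but must have the same size in every other dimension.
out is the output tensor.
Demo: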
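A quick check of the shape rule, as a sketch with arbitrary sizes: the inputs may differ only along the concatenation dimension, an input that is empty along that dimension is accepted, and a size mismatch in any other dimension makes torch.cat raise an error (a RuntimeError, as far as I know).

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)    # differs from `a` only in dim 0
c = torch.empty(0, 3)   # empty along dim 0, same size as `a` in dim 1

print(torch.cat([a, b, c], dim=0).shape)  # torch.Size([6, 3])

# Along dim 1 the sizes in dim 0 must match, but a has 2 rows and b has 4.
try:
    torch.cat([a, b], dim=1)
except RuntimeError as err:
    print("RuntimeError:", err)
```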
Microsoft Windows [Version 10.0.18363.1256]
(c) 2019 Microsoft Corporation. All rights reserved.
C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000022EA913D330>
>>> x1 = torch.randn(2, 1)
>>> x2 = torch.randn(2, 1)
>>> x3 = torch.randn(2, 1)
>>> x_vertical = torch.randn(6, 1)
>>> x_horizontal = torch.randn(2, 3)
>>> x1
tensor([[ 0.2824],
        [-0.3715]])
>>> x2
tensor([[ 0.9088],
        [-1.7601]])
>>> x3
tensor([[-0.1806],
        [ 2.0937]])
>>> x_vertical
tensor([[ 1.0406],
        [-1.7651],
        [ 1.1216],
        [ 0.8440],
        [ 0.1783],
        [ 0.6859]])
>>> x_horizontal
tensor([[-1.5942, -0.2006, -0.4050],
        [-0.5556,  0.9571,  0.7435]])
>>> torch.cat((x1, x2, x3), 0)
tensor([[ 0.2824],
        [-0.3715],
        [ 0.9088],
        [-1.7601],
        [-0.1806],
        [ 2.0937]])
>>> torch.cat([x1, x2, x3], 0)
tensor([[ 0.2824],
        [-0.3715],
        [ 0.9088],
        [-1.7601],
        [-0.1806],
        [ 2.0937]])
>>> torch.cat((x1, x2, x3), 1)
tensor([[ 0.2824,  0.9088, -0.1806],
        [-0.3715, -1.7601,  2.0937]])
>>> torch.cat([x1, x2, x3], 1)
tensor([[ 0.2824,  0.9088, -0.1806],
        [-0.3715, -1.7601,  2.0937]])
>>>
>>>
>>> x_vertical
tensor([[ 1.0406],
        [-1.7651],
        [ 1.1216],
        [ 0.8440],
        [ 0.1783],
        [ 0.6859]])
>>> torch.cat((x1, x2, x3), 0,out=x_vertical)
tensor([[ 0.2824],
        [-0.3715],
        [ 0.9088],
        [-1.7601],
        [-0.1806],
        [ 2.0937]])
>>> x_vertical
tensor([[ 0.2824],
        [-0.3715],
        [ 0.9088],
        [-1.7601],
        [-0.1806],
        [ 2.0937]])
>>> x_horizontal
tensor([[-1.5942, -0.2006, -0.4050],
        [-0.5556,  0.9571,  0.7435]])
>>> torch.cat((x1, x2, x3), 1,out=x_horizontal)
tensor([[ 0.2824,  0.9088, -0.1806],
        [-0.3715, -1.7601,  2.0937]])
>>> x_horizontal
tensor([[ 0.2824,  0.9088, -0.1806],
        [-0.3715, -1.7601,  2.0937]])
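The same experiment can also be written as a standalone script, which is easier to rerun than an interactive session. This is a sketch: it reuses the seed from the session above but only checks shapes and equalities, because exact random values are not guaranteed to be identical across PyTorch versions or platforms. It also adds dim=-1, which for 2-D inputs refers to the same dimension as dim=1, and confirms that the tensor returned with out= shares memory with the preallocated buffer.

```python
import torch

torch.manual_seed(20200910)  # same seed as the session above
x1 = torch.randn(2, 1)
x2 = torch.randn(2, 1)
x3 = torch.randn(2, 1)

# Any Python sequence works: a tuple and a list give the same result.
v_tuple = torch.cat((x1, x2, x3), dim=0)
v_list = torch.cat([x1, x2, x3], dim=0)
print(v_tuple.shape, torch.equal(v_tuple, v_list))  # torch.Size([6, 1]) True

# dim=1 places the columns side by side; for 2-D inputs dim=-1 is the same dim.
h = torch.cat((x1, x2, x3), dim=1)
print(h.shape, torch.equal(h, torch.cat((x1, x2, x3), dim=-1)))  # torch.Size([2, 3]) True

# out= writes the result into a preallocated tensor of the matching shape.
buf = torch.empty(2, 3)
result = torch.cat((x1, x2, x3), dim=1, out=buf)
print(torch.equal(buf, h), result.data_ptr() == buf.data_ptr())  # True True
```

As in the session, the buffer passed to out= already has the result's shape, so its contents are simply overwritten.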