前言
这篇作为自己读论文代码过程中一些简单的代码备忘读记吧,方便查阅。
一、torch.cat() 、torch.stack()
拼接张量:torch.cat() 、torch.stack()
- torch.cat(inputs, dimension=0) → Tensor
在给定维度上对输入的张量序列 seq 进行连接操作
举个例子:
>>> ``import` `torch``
>>> x ``=` `torch.randn(``2``, ``3``)``
>>> x``tensor([[``-``0.1997``, ``-``0.6900``, ``0.7039``],`` ``[ ``0.0268``, ``-``1.0140``, ``-``2.9764``]])``
>>> torch.cat((x, x, x), ``0``) ``
# 在 0 维(纵向)进行拼接``tensor([[``-``0.1997``, ``-``0.6900``, ``0.7039``],`` ``[ ``0.0268``, ``-``1.0140``, ``-``2.9764``],`` ``[``-``0.1997``, ``-``0.6900``, ``0.7039``],`` ``[ ``0.0268``, ``-``1.0140``, ``-``2.9764``],`` ``[``-``0.1997``, ``-``0.6900``, ``0.7039``],`` ``[ ``0.0268``, ``-``1.0140``, ``-``2.9764``]])``
>>> torch.cat((x, x, x), ``1``) ``
# 在 1 维(横向)进行拼接``tensor([[``-``0.1997``, ``-``0.6900``, ``0.7039``, ``-``0.1997``, ``-``0.6900``, ``0.7039``, ``-``0.1997``, ``-``0.6900``,`` ``0.7039``],`` ``[ ``0.0268``, ``-``1.0140``, ``-``2.9764``, ``0.0268``, ``-``1.0140``, ``-``2.9764``, ``0.0268``, ``-``1.0140``,`` ``-``2.9764``]])``
>>> y1 ``=` `torch.randn(``5``, ``3``, ``6``)``
>>> y2 ``=` `torch.randn(``5``, ``3``, ``6``)``
>>>> torch.cat([y1, y2], ``2``).size()``torch.Size([``5``, ``3``, ``12``])``
>>> torch.cat([y1, y2], ``1``).size()``torch.Size([``5``, ``6``, ``6``])
对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。
- torch.stack(sequence, dim=0)
沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状
举个例子:
>>> x1 ``=` `torch.randn(``2``, ``3``)``
>>>> x2 ``=` `torch.randn(``2``, ``3``)``
>>>> torch.stack((x1, x2), ``0``).size() ``
># 在 0 维插入一个维度,进行区分拼接``torch.Size([``2``, ``2``, ``3``])``
>>>> torch.stack((x1, x2), ``1``).size() `
># 在 1 维插入一个维度,进行组合拼接``
>torch.Size([``2``, ``2``, ``3``])``
>>>> torch.stack((x1, x2), ``2``).size()``torch.Size([``2``, ``3``, ``2``])``
>>>> torch.stack((x1, x2), ``0``)``tensor([[[``-``0.3499``, ``-``0.6124``, ``1.4332``],`` ``[ ``0.1516``, ``-``1.5439``, ``-``0.1758``]],` ` ``[[``-``0.4678``, ``-``1.1430``, ``-``0.5279``],`` ``[``-``0.4917``, ``-``0.6504``, ``2.2512``]]])``>>> torch.stack((x1, x2), ``1``)``tensor([[[``-``0.3499``, ``-``0.6124``, ``1.4332``],`` ``[``-``0.4678``, ``-``1.1430``, ``-``0.5279``]],` ` ``[[ ``0.1516``, ``-``1.5439``, ``-``0.1758``],`` ``[``-``0.4917``, ``-``0.6504``, ``2.2512``]]])`&