【PyTorch】张量 (Tensor) 的拆分与拼接 (split, chunk, cat, stack)

Overview

在 PyTorch 中,对张量 (Tensor) 进行拆分通常会用到两个函数:

  • torch.split [按块大小拆分张量]
  • torch.chunk [按块数拆分张量]

而对张量 (Tensor) 进行拼接通常会用到另外两个函数:

  • torch.cat [按已有维度拼接张量]
  • torch.stack [按新维度拼接张量]

它们的作用相似,但实际效果并不完全相同,以下会通过官方文档及实例代码来进行说明,以示区别

张量 (Tensor) 的拆分

torch.split 函数

torch.split(tensor, split_size_or_sections, dim = 0)

块大小拆分张量
tensor 为待拆分张量
dim 指定张量拆分的所在维度,即在第几维对张量进行拆分
split_size_or_sections 表示在 dim 维度拆分张量时每一块在该维度的尺寸大小 (int),或各块尺寸大小的列表 (list)
指定每一块的尺寸大小后,如果在该维度无法整除,则最后一块会取余数,尺寸较小一些
如:长度为 10 的张量,按单位长度 3 拆分,则前三块长度为 3,最后一块长度为 1
函数返回:所有拆分后的张量所组成的 tuple
函数并不会改变原 tensor

torch.split 官方文档

Splits the tensor into chunks. Each chunk is a view of the original tensor.

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

Parameters:

  • tensor (Tensor) – tensor to split
  • split_size_or_sections (int) or (list(int)) – size of a single chunk or list of sizes for each chunk
  • dim (int) – dimension along which to split the tensor

实例代码:

In [1]: X = torch.randn(6, 2)

In [2]: X
Out[2]:
tensor([[-0.3711,  0.7372],
        [ 0.2608, -0.1129],
        [-0.2785,  0.1560],
        [-0.7589, -0.8927],
        [ 0.1480, -0.0371],
        [-0.8387,  0.6233]])

In [3]: torch.split(X, 2, dim = 0)
Out[3]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [4]: torch.split(X, 3, dim = 0)
Out[4]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560]]),
 tensor([[-0.7589, -0.8927],
         [ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [5]: torch.split(X, 4, dim = 0)
Out[5]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [6]: torch.split(X, 1, dim = 1)
Out[6]:
(tensor([[-0.3711],
         [ 0.2608],
         [-0.2785],
         [-0.7589],
         [ 0.1480],
         [-0.8387]]),
 tensor([[ 0.7372],
         [-0.1129],
         [ 0.1560],
         [-0.8927],
         [-0.0371],
         [ 0.6233]]))

torch.chunk 函数

torch.chunk(input, chunks, dim = 0)

块数拆分张量
input 为待拆分张量
dim 指定张量拆分的所在维度,即在第几维对张量进行拆分
chunks 表示在 dim 维度拆分张量时最后所分出的总块数 (int),根据该块数进行平均拆分
指定总块数后,如果在该维度无法整除,则每块长度向上取整,最后一块会取余数,尺寸较小一些,若余数恰好为 0,则会只分出 chunks - 1
如:

  • 长度为 6 的张量,按 4 块拆分,则只分出三块,长度为 2 (6 / 4 = 1.5 → 2)
  • 长度为 10 的张量,按 4 块拆分,则前三块长度为 3 (10 / 4 = 2.5 → 3),最后一块长度为 1

函数返回:所有拆分后的张量所组成的 tuple
函数并不会改变原 input

torch.chunk 官方文档

Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.

Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.

Parameters:

  • input (Tensor) – the tensor to split
  • chunks (int) – number of chunks to return
  • dim (int) – dimension along which to split the tensor

实例代码:

In [1]: X = torch.randn(6, 2)

In [2]: X
Out[2]:
tensor([[-0.3711,  0.7372],
        [ 0.2608, -0.1129],
        [-0.2785,  0.1560],
        [-0.7589, -0.8927],
        [ 0.1480, -0.0371],
        [-0.8387,  0.6233]])

In [3]: torch.chunk(X, 2, dim = 0)
Out[3]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129],
         [-0.2785,  0.1560]]),
 tensor([[-0.7589, -0.8927],
         [ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [4]: torch.chunk(X, 3, dim = 0)
Out[4]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [5]: torch.chunk(X, 4, dim = 0)
Out[5]:
(tensor([[-0.3711,  0.7372],
         [ 0.2608, -0.1129]]),
 tensor([[-0.2785,  0.1560],
         [-0.7589, -0.8927]]),
 tensor([[ 0.1480, -0.0371],
         [-0.8387,  0.6233]]))

In [6]: Y = torch.randn(10, 2)

In [6]: Y
Out[6]:
tensor([[-0.9749,  1.3103],
        [-0.4138, -0.8369],
        [-0.1138, -1.6984],
        [ 0.7512, -0.3417],
        [-1.4575, -0.4392],
        [-0.2035, -0.2962],
        [-0.7533, -0.8294],
        [ 0.0104, -1.3582],
        [-1.5781,  0.8594],
        [ 0.0286,  0.7611]])

In [7]: torch.chunk(Y, 4, dim = 0)
Out[7]:
(tensor([[-0.9749,  1.3103],
         [-0.4138, -0.8369],
         [-0.1138, -1.6984]]),
 tensor([[ 0.7512, -0.3417],
         [-1.4575, -0.4392],
         [-0.2035, -0.2962]]),
 tensor([[-0.7533, -0.8294],
         [ 0.0104, -1.3582],
         [-1.5781,  0.8594]]),
 tensor([[0.0286, 0.7611]]))

张量 (Tensor) 的拼接

torch.cat 函数

torch.cat(tensors, dim = 0, out = None)

已有维度拼接张量
tensors 为待拼接张量的序列,通常为 tuple
dim 指定张量拼接的所在维度,即在第几维对张量进行拼接,除该拼接维度外,其余维度上待拼接张量的尺寸必须相同
out 表示在拼接张量的输出,也可直接使用函数返回值
函数返回:拼接后所得到的张量
函数并不会改变原 tensors

torch.cat 官方文档

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().

torch.cat() can be best understood via examples.

Parameters:

  • tensors (sequence of Tensors) – any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension
  • dim (int, optional) – the dimension over which the tensors are concatenated
  • out (Tensor, optional) – the output tensor

实例代码:(引用自官方文档)

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), dim = 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), dim = 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497]])

torch.stack 函数

torch.stack(tensors, dim = 0, out = None)

新维度拼接张量
tensors 为待拼接张量的序列,通常为 tuple
dim 指定张量拼接的新维度对应已有维度的插入索引,即在原来第几维的位置上插入新维度对张量进行拼接,待拼接张量在所有已有维度上的尺寸必须完全相同
out 表示在拼接张量的输出,也可直接使用函数返回值
函数返回:拼接后所得到的张量
函数并不会改变原 tensors

torch.stack 官方文档

Concatenates sequence of tensors along a new dimension.

All tensors need to be of the same size.

Parameters:

  • tensors (sequence of Tensors) – sequence of tensors to concatenate
  • dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive)
  • out (Tensor, optional) – the output tensor.

实例代码:

In [1]: x = torch.randn(2, 3)

In [2]: x
Out[2]:
tensor([[-0.0288,  0.6936, -0.6222],
        [ 0.8786, -1.1464, -0.6486]])

In [3]: torch.stack((x, x, x), dim = 0)
Out[3]:
tensor([[[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]],

        [[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]],

        [[-0.0288,  0.6936, -0.6222],
         [ 0.8786, -1.1464, -0.6486]]])

In [4]: torch.stack((x, x, x), dim = 0).shape
Out[4]: torch.Size([3, 2, 3])

In [5]: torch.stack((x, x, x), dim = 1)
Out[5]:
tensor([[[-0.0288,  0.6936, -0.6222],
         [-0.0288,  0.6936, -0.6222],
         [-0.0288,  0.6936, -0.6222]],

        [[ 0.8786, -1.1464, -0.6486],
         [ 0.8786, -1.1464, -0.6486],
         [ 0.8786, -1.1464, -0.6486]]])

In [6]: torch.stack((x, x, x), dim = 1).shape
Out[6]: torch.Size([2, 3, 3])

希望能够对大家有所帮助 ~

  • 25
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值