pytorch中的torch.cat()和torch.chunk()

本文介绍了PyTorch中处理张量的两个重要函数:torch.cat()用于沿指定维度拼接多个张量,而torch.chunk()则将一个张量按维度分割为多个子张量。通过实例展示了如何使用这两个函数进行张量的组合与拆分,对于理解和操作PyTorch张量具有指导意义。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

用法介绍

pytorch中张量进行拼接和分割的函数分别是torch.cat()torch.chunk()torch.cat()是将多个张量组成的元组按照指定的维度进行拼接。torch.chunk()是对一个张量按照某个维度分割成多个子张量块。它们具体的用法如下所示

torch.cat(tensors, dim=0, *, out=None)⟶\longrightarrowTensor

  • tensors (tuple of tensor):张量组成的元组
  • dim (int):按照某个维度对多个张量进行拼接

注意: 如果多个张量按照某个维度进行拼接,那么其它的维度要一致。

torch.chunk(input, chunks, dim=0)⟶\longrightarrowList of Tensors

  • input (Tensor):要被分割的张量
  • chunks (int):被分割的张量数
  • dim (int):按照某个维度对张量进行分割

代码示例

torch.cat()的代码示例如下所示

>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028]])
>>> torch.cat((x,x,x),0)
tensor([[-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028],
        [-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028],
        [-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028]])
>>> torch.cat((x,x,x),1)
tensor([[-0.5654,  0.7048,  0.5851, -0.5654,  0.7048,  0.5851, -0.5654,  0.7048,
          0.5851],
        [-1.3871,  0.5481,  0.3028, -1.3871,  0.5481,  0.3028, -1.3871,  0.5481,
          0.3028]])

torch.chunk()的代码示例如下所示

>>> import torch
>>> x = torch.randn(8,8)
>>> x
tensor([[ 1.0272,  1.5964,  0.1502,  1.3435, -0.1774,  0.7908,  0.6920,  1.0908],
        [ 0.8614, -0.3212,  0.4715,  0.1476,  1.7950,  1.8308, -0.1419, -0.1448],
        [-0.7407,  0.5510,  0.1284,  0.1485,  0.2997, -0.8133,  1.5608,  0.0682],
        [ 0.7217,  0.5292,  0.2469,  0.1823, -0.6200,  0.9436, -0.5221, -0.9343],
        [-2.0195, -2.3613, -0.6441, -1.7863,  1.4207,  0.4124,  0.5508, -0.2569],
        [ 0.4582, -1.6445, -0.6813, -0.8802,  0.9870, -0.6599, -0.4719,  0.3088],
        [-1.6415, -0.9834,  0.1687,  0.0159,  0.4456, -0.1823,  0.9652, -0.2785],
        [ 0.8765,  0.8214,  1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]])
>>> x.chunk(chunks=2, dim=0)
(tensor([[ 1.0272,  1.5964,  0.1502,  1.3435, -0.1774,  0.7908,  0.6920,  1.0908],
        [ 0.8614, -0.3212,  0.4715,  0.1476,  1.7950,  1.8308, -0.1419, -0.1448],
        [-0.7407,  0.5510,  0.1284,  0.1485,  0.2997, -0.8133,  1.5608,  0.0682],
        [ 0.7217,  0.5292,  0.2469,  0.1823, -0.6200,  0.9436, -0.5221, -0.9343]]), 
 tensor([[-2.0195, -2.3613, -0.6441, -1.7863,  1.4207,  0.4124,  0.5508, -0.2569],
        [ 0.4582, -1.6445, -0.6813, -0.8802,  0.9870, -0.6599, -0.4719,  0.3088],
        [-1.6415, -0.9834,  0.1687,  0.0159,  0.4456, -0.1823,  0.9652, -0.2785],
        [ 0.8765,  0.8214,  1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]]))
>>> x.chunk(chunks=2, dim=1)
(tensor([[ 1.0272,  1.5964,  0.1502,  1.3435],
        [ 0.8614, -0.3212,  0.4715,  0.1476],
        [-0.7407,  0.5510,  0.1284,  0.1485],
        [ 0.7217,  0.5292,  0.2469,  0.1823],
        [-2.0195, -2.3613, -0.6441, -1.7863],
        [ 0.4582, -1.6445, -0.6813, -0.8802],
        [-1.6415, -0.9834,  0.1687,  0.0159],
        [ 0.8765,  0.8214,  1.0971, -0.4150]]), 
 tensor([[-0.1774,  0.7908,  0.6920,  1.0908],
        [ 1.7950,  1.8308, -0.1419, -0.1448],
        [ 0.2997, -0.8133,  1.5608,  0.0682],
        [-0.6200,  0.9436, -0.5221, -0.9343],
        [ 1.4207,  0.4124,  0.5508, -0.2569],
        [ 0.9870, -0.6599, -0.4719,  0.3088],
        [ 0.4456, -0.1823,  0.9652, -0.2785],
        [-0.9499, -0.5875, -1.3902, -0.9129]]))
### PyTorch `split` `chunk` 函数的区别 #### 功能描述 PyTorch 的 `split` `chunk` 都用于分割张量,但两者的工作方式有所不同。 对于 `torch.split(tensor, split_size_or_sections, dim=0)` 函数而言,此方法允许指定要切割成的小片段大小或是各个部分的具体尺寸列表。如果提供的是单个整数,则表示每一部分的长度;如果是元组或列表,则精确指定了各分片的大小[^1]。 ```python import torch tensor_example = torch.arange(8) # 使用 split 方法按每份两个元素来划分张量 result_split = torch.split(tensor_example, 2) print([t.tolist() for t in result_split]) ``` 另一方面,`torch.chunk(tensor, chunks, dim=0)` 则更倾向于平均分配输入张量到指定数量的部分中去。当无法均匀切分时,最后一块可能会小于其他部分。这里只接受一个参数来定义希望得到多少个子集。 ```python # 使用 chunk 方法将张量分为三等分 result_chunk = torch.chunk(tensor_example, 3) print([t.tolist() for t in result_chunk]) ``` 这两种操作都支持通过设置维度参数 (`dim`) 来控制沿哪个轴执行拆分,默认情况下是在第零维上工作。 #### 输出差异展示 上述代码会分别打印出由 `split` `chunk` 返回的结果: - 对于 `split`: 如果传入了合适的总等于原始张量长度的分割尺寸数组,那么将会获得完全按照给定规格被分开的新张量集合; - 而对于 `chunk`: 不管怎样都会尝试尽可能公平地把原数据分成所请求的数量级,即使这意味着某些部分可能比其他的稍大一些或小一点。 因此,在实际编程过程中可以根据具体需求选择合适的方法来进行张量的操作处理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值