【学习记录】torch.chunk()函数用法

torch.chunk 是 PyTorch 中一个用于将张量沿指定维度分割成多个子张量的函数。它对于处理需要均匀分割的张量非常有用。以下是 torch.chunk 的详细介绍,包括其语法、参数、返回值和示例。

1 语法

torch.chunk(input, chunks, dim=0)

参数:

  • input (Tensor): 要分割的输入张量。
  • chunks (int): 要将张量分割成的块数。如果不能均匀分割,最后一个块将会比其他块小,剩多少有多少。
  • dim (int): 指定要沿着哪个维度进行分割。默认值为 0,即沿着第一个维度分割。

返回值:

  • output (list of Tensors): 一个包含分割后子张量的列表。列表的长度等于 chunks 参数的值。

示例一:假设我们有一个形状为 (10, 5) 的张量,并且我们希望将其沿第一个维度分割成 5 个块。每个块将包含 2 行数据。

import torch

# 创建一个形状为 (10, 5) 的张量
x = torch.arange(50).reshape(10, 5)

# 将张量沿第一个维度分割成 5 个块
chunks = torch.chunk(x, 5, dim=0)

# 打印每个子张量
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}:\n{chunk}\n")

输出结果:

Chunk 0:
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]])

Chunk 1:
tensor([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])

Chunk 2:
tensor([[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]])

Chunk 3:
tensor([[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]])

Chunk 4:
tensor([[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49]])

处理不能均匀分割的情况

当张量不能被均匀分割时,最后一个子张量将包含剩余的元素。

# 创建一个形状为 (7, 3) 的张量
y = torch.arange(21).reshape(7, 3)

# 将张量沿第一个维度分割成 4 个块
chunks = torch.chunk(y, 4, dim=0)

# 打印每个子张量
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}:\n{chunk}\n")

输出:

Chunk 0:
tensor([[0, 1, 2],
        [3, 4, 5]])

Chunk 1:
tensor([[6, 7, 8],
        [9, 10, 11]])

Chunk 2:
tensor([[12, 13, 14],
        [15, 16, 17]])

Chunk 3:
tensor([[18, 19, 20]])


在这个例子中,前 3 个子张量每个包含 2 行数据,最后一个子张量包含剩余的 1 行数据。

用法示例:处理模型预测输出

在深度学习中,我们可能需要将模型的预测输出分割成不同的部分,以便进一步处理。例如:

import torch

# 模拟模型的预测输出
pred = {
    'loc_row': torch.randn(40, 5),
    'loc_col': torch.randn(40, 5),
    'exist_row': torch.randn(40, 5),
    'exist_col': torch.randn(40, 5)
}

# 将预测结果按行和列分割成多个部分
loc_row, loc_row_left, loc_row_right, _, _ = torch.chunk(pred['loc_row'], 5, dim=1)
loc_col, _, _, loc_col_up, loc_col_down = torch.chunk(pred['loc_col'], 5, dim=1)
exist_row, exist_row_left, exist_row_right, _, _ = torch.chunk(pred['exist_row'], 5, dim=1)
exist_col, _, _, exist_col_up, exist_col_down = torch.chunk(pred['exist_col'], 5, dim=1)

# 打印分割后的张量形状
print(loc_row.shape)        # torch.Size([40, 1])
print(loc_row_left.shape)   # torch.Size([40, 1])
print(loc_row_right.shape)  # torch.Size([40, 1])
print(loc_col_up.shape)     # torch.Size([40, 1])
print(loc_col_down.shape)   # torch.Size([40, 1])
print(exist_row.shape)      # torch.Size([40, 1])
print(exist_row_left.shape) # torch.Size([40, 1])
print(exist_row_right.shape)# torch.Size([40, 1])
print(exist_col_up.shape)   # torch.Size([40, 1])
print(exist_col_down.shape) # torch.Size([40, 1])

torch.chunk 是一个非常有用的函数,特别是在需要将张量均匀分割成多个子张量的情况下。通过指定分割的维度和块数,我们可以方便地处理各种张量分割任务。在深度学习中,它常用于处理模型的输出,将复杂的预测结果分解为更易处理的部分。
可以用来抽出张量中需要的部分。

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

超好的小白

没体验过打赏,能让我体验一次吗

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值