torch.chunk 是 PyTorch 中一个用于将张量沿指定维度分割成多个子张量的函数。它对于处理需要均匀分割的张量非常有用。以下是 torch.chunk 的详细介绍,包括其语法、参数、返回值和示例。
1 语法
torch.chunk(input, chunks, dim=0)
参数:
- input (Tensor): 要分割的输入张量。
- chunks (int): 要将张量分割成的块数。如果不能均匀分割,最后一个块将会比其他块小,剩多少有多少。
- dim (int): 指定要沿着哪个维度进行分割。默认值为 0,即沿着第一个维度分割。
返回值:
- output (list of Tensors): 一个包含分割后子张量的列表。列表的长度等于 chunks 参数的值。
示例一:假设我们有一个形状为 (10, 5) 的张量,并且我们希望将其沿第一个维度分割成 5 个块。每个块将包含 2 行数据。
import torch
# 创建一个形状为 (10, 5) 的张量
x = torch.arange(50).reshape(10, 5)
# 将张量沿第一个维度分割成 5 个块
chunks = torch.chunk(x, 5, dim=0)
# 打印每个子张量
for i, chunk in enumerate(chunks):
print(f"Chunk {
i}:\n{
chunk}\n")
输出结果:
Chunk 0:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9]])
Chunk 1:
tensor(