chunk可以对张量进行分块,返回一个张量列表
torch.chunk(tensor, chunks, dim=0) → List of Tensors
Splits a tensor into a specific number of chunks.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.(如果指定轴的元素个数被chunks除不尽,那么最后一块的元素个数变少)
parameters:
- tensor : 输入的张量
- chunks : 分割的块数
- dim : 沿着哪个轴分割
import torch
arr = torch.tensor(np.arange(24).reshape((6, 2, 2)))
x1, x2 = arr.chunk(chunks=2, dim=0)
x1, x2 = torch.chunk(arr, chunks=2, dim=0) # 等价于上行代码
>>>
tensor([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]]], dtype=torch.int32)
tensor([[[12, 13],
[14, 15]],
[[16, 17],
[18, 19]],
[[20, 21],
[22, 23]]], dtype=torch.int32)
x1, x2, x3 = arr.chunk(chunks=4, dim=0) # 不能整除
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]], dtype=torch.int32)
tensor([[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]], dtype=torch.int32)
tensor([[[16, 17],
[18, 19]],
[[20, 21],
[22, 23]]], dtype=torch.int32)