torch.chunk(tensor, chunks, dim)
含义:在给定维度上对所给张量进行分块。
参数解释:
- Tensor -- 要进行分快的张量
- chunks -- 分块个数
- dim -- 维度,按照此维度进行分块
如下代码:
import torch
x = torch.randn(3, 3)
x
tensor([[ 0.5149, 1.0009, -0.7242],
[-0.9385, -0.1157, -1.6772],
[ 1.0187, -2.3512, 1.0539]])
torch.chunk(x, 3, dim = 0)
(tensor([[ 0.5149, 1.0009, -0.7242]]),
tensor([[-0.9385, -0.1157, -1.6772]]),
tensor([[ 1.0187, -2.3512, 1.0539]]))
torch.chunk(x, 3, dim = 1)
(tensor([[ 0.5149],
[-0.9385],
[ 1.0187]]),
tensor([[ 1.0009],
[-0.1157],
[-2.3512]]),
tensor([[-0.7242],
[-1.6772],
[ 1.0539]]))
torch.chunk(x, 2, dim = 1) # 注意当chunks为偶数时,而原始张量为所分维数为奇数时的变化
(tensor([[ 0.5149, 1.0009],
[-0.9385, -0.1157],
[ 1.0187, -2.3512]]),
tensor([[-0.7242],
[-1.6772],
[ 1.0539]]))
# 另外也可以这么使用
x.chunk(2, dim = 1)
(tensor([[ 0.5149, 1.0009],
[-0.9385, -0.1157],
[ 1.0187, -2.3512]]),
tensor([[-0.7242],
[-1.6772],
[ 1.0539]]))