2. torch.chunk(tensor, chunks, dim)
说明:在给定的维度上讲张量进行分块。
参数:
- tensor(Tensor) -- 待分块的输入张量
- chunks(int) -- 分块的个数
- dim(int) -- 维度,沿着此维度进行分块
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 1.0103, 2.3358, -1.9236],
[-0.3890, 0.6594, 0.6664],
[ 0.5240, -1.4193, 0.1681]])
>>> torch.chunk(x, 3, dim=0)
(tensor([[ 1.0103, 2.3358, -1.9236]]), tensor([[-0.3890, 0.6594, 0.6664]]), tensor([[ 0.5240, -1.4193, 0.1681]]))
>>> torch.chunk(x, 3, dim=1)
(tensor([[ 1.0103],
[-0.3890],
[ 0.5240]]), tensor([[ 2.3358],
[ 0.6594],
[-1.4193]]), tensor([[-1.9236],
[ 0.6664],
[ 0.1681]]))
>>> torch.chunk(x, 2, dim=1)
(tensor([[ 1.0103, 2.3358],
[-0.3890, 0.6594],
[ 0.5240, -1.4193]]), tensor([[-1.9236],
[ 0.6664],
[ 0.1681]]))