torch.chunk:用来将tensor分成很多个块,简而言之我理解的就是切分吧,可以在不同维度上切分。
torch.chunk(tensor,chunk数,维度)
代码示例:
import torch
a=torch.tensor([[[1,2],[3,4]],
[[5,6],[7,8]]])
b=torch.chunk(a,2,1)
print(a)
print(b)
输出:
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
(tensor([[[1, 2]],
[[5, 6]]]),
tensor([[[3, 4]],
[[7, 8]]]))