Example:
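import torch as t

# A 3-D tensor of shape (2, 2, 3)
a = t.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
print(a.shape)
print(a)
# Integer split size 1 on the default dim=0: chunks of size 1 along dim 0
print(t.split(a, 1))
# List of sizes [1, 1] along dim=1: one chunk per row of each 2x3 matrix
print(t.split(a, [1, 1], 1))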
Output:
torch.Size([2, 2, 3])
tensor([[[1, 2, 3],
         [4, 5, 6]],

        [[1, 2, 3],
         [4, 5, 6]]])
(tensor([[[1, 2, 3],
         [4, 5, 6]]]), tensor([[[1, 2, 3],
         [4, 5, 6]]]))
(tensor([[[1, 2, 3]],

        [[1, 2, 3]]]), tensor([[[4, 5, 6]],

        [[4, 5, 6]]]))
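A quick way to confirm what the two calls in the example did is to inspect the shapes of the returned pieces and glue them back together with torch.cat along the same dimension. This is a minimal sketch that reuses the tensor `a` from the example; the variable names parts_dim0, parts_dim1 and uneven are illustrative only. The last lines show the documented behaviour for an integer split size that does not divide the dimension evenly: the final chunk is simply smaller.

```python
import torch as t

a = t.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])  # shape (2, 2, 3)

# Integer split size: equal chunks of that size along dim (default dim=0)
parts_dim0 = t.split(a, 1)
# List of sizes: one chunk per entry along the given dim; the sizes must sum to a.size(dim)
parts_dim1 = t.split(a, [1, 1], 1)

print([p.shape for p in parts_dim0])  # [torch.Size([1, 2, 3]), torch.Size([1, 2, 3])]
print([p.shape for p in parts_dim1])  # [torch.Size([2, 1, 3]), torch.Size([2, 1, 3])]

# Concatenating the pieces along the same dim restores the original tensor
assert t.equal(t.cat(parts_dim0, dim=0), a)
assert t.equal(t.cat(parts_dim1, dim=1), a)

# If the integer split size does not divide the dim evenly, the last chunk is smaller
uneven = t.split(t.arange(5), 2)
print([p.tolist() for p in uneven])  # [[0, 1], [2, 3], [4]]
```

Note that torch.split returns a tuple of views into `a`, not copies, so writing into one of the pieces also changes `a`.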
Appendix: