import torch  # redundant but harmless if torch is already imported earlier in the file

# Ellipsis indexing on a 3-D tensor: `...` expands to "all leading dimensions",
# so a[..., k] selects index k along the LAST dimension (same as a[:, :, k] here).
# The sample outputs below are kept as comments; left as bare text they would be
# executed as calls to an undefined name `tensor` and raise NameError.
a = torch.rand((2, 2, 2))
print(a)
# Example output (values are random, so they differ on every run):
# tensor([[[0.6601, 0.5124],
#          [0.1438, 0.2335]],
#         [[0.5657, 0.0586],
#          [0.0366, 0.1247]]])
print(a[..., 0])  # first element of every innermost row -> shape (2, 2)
# tensor([[0.6601, 0.1438],
#         [0.5657, 0.0366]])
print(a[..., 1])  # second element of every innermost row -> shape (2, 2)
# tensor([[0.5124, 0.2335],
#         [0.0586, 0.1247]])
# Ellipsis indexing with a singleton middle dimension: a has shape (2, 1, 2),
# and a[..., k] drops only the last dimension, leaving shape (2, 1).
# The sample outputs below are kept as comments; left as bare text they would be
# executed as calls to an undefined name `tensor` and raise NameError.
a = torch.rand((2, 1, 2))
print(a)
# Example output (values are random, so they differ on every run):
# tensor([[[0.7589, 0.7766]],
#         [[0.2806, 0.7683]]])
print(a[..., 0])  # index 0 along the last dimension -> shape (2, 1)
# tensor([[0.7589],
#         [0.2806]])
print(a[..., 1])  # index 1 along the last dimension -> shape (2, 1)
# tensor([[0.7766],
#         [0.7683]])