# PyTorch: `torch.multinomial()` example
import torch
# Weights: the integers 0..11 as float32, arranged as 3 rows of 4.
# Each row acts as an unnormalized categorical distribution over the
# column indices 0..3; multinomial does not require rows to sum to 1.
w = torch.arange(12, dtype=torch.float).reshape(3, 4)
print(w)

# Draw 2 column indices from every row. Sampling is without replacement
# by default, so the two indices within a row are distinct, and the
# zero-weight entry w[0, 0] can never be chosen. The draw is random;
# call torch.manual_seed(...) beforehand if a reproducible result is needed.
out = w.multinomial(num_samples=2)
print(out)
# Output (the second tensor is sampled at random and varies between runs):
# tensor([[ 0.,  1.,  2.,  3.],
#         [ 4.,  5.,  6.,  7.],
#         [ 8.,  9., 10., 11.]])
# tensor([[2, 1],
#         [3, 1],
#         [3, 2]])