例如,将矩阵每一行的0都甩到最后面。代码的第一行和最后一行表示输入和想要的输出。
x = torch.tensor([[2,3,0,2,1,0,6,],[0,0,3,9,0,1,0]])
# [[2, 3, 0, 2, 1, 0, 6],
# [0, 0, 3, 9, 0, 1, 0]]
x_01 = (x == 0)
# [[0, 0, 1, 0, 0, 1, 0],
# [1, 1, 0, 0, 1, 0, 1]]
_, idx = x_01.sort(1)
# [[0, 1, 3, 4, 6, 2, 5],
# [2, 3, 5, 0, 1, 4, 6]]
y = x.gather(1, idx)
# [[2, 3, 2, 1, 6, 0, 0],
# [3, 9, 1, 0, 0, 0, 0]]