torch.gather(tensor,待填充的索引tensor,dim):在原tensor中沿dim方向,在每单位数组中取出对应索引的元素。
tensor = torch.tensor([[1, 10, 100],
[2, 20, 200]])
"""
以dim=0为例,沿tensor的行方向看去,三单位数组分别是
[1, [10, [100,
2] 20] 200]
"""
index_tensor = torch.LongTensor([[1]])
结果是 tensor([[2]])
index_tensor = torch.LongTensor([[1,0]])
结果是 tensor([[2, 10]])
index_tensor = torch.LongTensor([[1,0,0]])
结果是 tensor([[2, 10, 100]])
index_tensor = torch.LongTensor([[1,0,0], [1,0,0]])
结果是 tensor([[2, 10, 100], [2, 10, 100]])
index_tensor = torch.LongTensor([[1,0,0], [1,0,0], ..., [1,0,0]])
结果是 tensor([[2, 10, 100], [2, 10, 100], ...., [2, 10, 100]])
tensor = torch.tensor([[[1, 10, 100],
[6, 60, 600]],
[[2, 11, 101],
[7, 61, 601]]])
"""
以dim=0为例, 沿上下堆叠的方向, 2x3个单位数组:
[[1, [10, [100,
2], 11], 101]],
[[6, [60, [600,
7], 61], 601]]
"""
index_tensor = torch.LongTensor([[[0, 1, 0]]])
结果是tensor([[[ 1, 11, 100]]])
index_tensor = torch.LongTensor([[[0, 1, 0], [0, 1, 0]]])
结果是tensor([[[ 1, 11, 100],
[ 6, 61, 601]]])
index_tensor = torch.LongTensor([[[0, 0, 0], [0, 0, 0]]])
结果是tensor([[[ 1, 10, 100],
[ 6, 60, 600]]])
index_tensor = torch.LongTensor([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], ...., [[0, 0, 0], [0, 0, 0]]])
结果是tensor([[[ 1, 10, 100],
[ 6, 60, 600]],
[[ 1, 10, 100],
[ 6, 60, 600]],
...
...
[[ 1, 10, 100],
[ 6, 60, 600]]])
""