def semantic_neighbor(x, index):
'''
假设x.shape=[B,L,C]=[2,3,4] index.shape=[B,L]=[2,3]
x = torch.tensor([
[[1, 2, 3, 4], # 样本1的3个元素,每个元素4维特征
[5, 6, 7, 8],
[9, 10, 11, 12]],
[[13, 14, 15, 16], # 样本2的3个元素
[17, 18, 19, 20],
[21, 22, 23, 24]]
])
# 索引张量 index (B=2, L=3)
index = torch.tensor([
[1, 0, 1], # 样本1的重组索引
[2, 1, 0] # 样本2的重组索引
])
'''
dim = index.dim()#dim=2
assert x.shape[:dim] == index.shape, "x ({:}) and index ({:}) shape incompatible".format(x.shape, index.shape)
for _ in range(x.dim() - index.dim()):
index = index.unsqueeze(-1)
'''
x.index=[2,3]
index = torch.tensor([
[[1],[0], [1]],
[[2], [1], [0]] ])
'''
index = index.expand(x.shape)
'''
x.index=[2,3,4]
index = torch.tensor([
[[1,1,1,1],
[0,0,0,0],
[1,1,1,1]
],
[[2,2,2,2],
[1,1,1,1],
[0,0,0,0]
] ])
'''
shuffled_x = torch.gather(x, dim=dim - 1, index=index)
'''
tensor([
[[ 5, 6, 7, 8], # 来自原始位置1
[ 1, 2, 3, 4], # 来自原始位置0
[ 5, 6, 7, 8]], # 来自原始位置1
[[21, 22, 23, 24], # 来自原始位置2
[17, 18, 19, 20], # 来自原始位置1
[13, 14, 15, 16]] # 来自原始位置0
])
'''
return shuffled_x
'''
另一个简单的示例:
源张量(3x4矩阵)
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
索引张量(2x3矩阵)
index = torch.tensor([[0, 1, 2],
[2, 1, 0]])
沿dim=0(行方向)收集
out = torch.gather(x, dim=0, index=index)
结果:
[[1, 6, 11], # 取x[0][0], x[1][1], x[2][2]
[9, 6, 3]] # 取x[2][0], x[1][1], x[0][2]]
'''
x.sort()
x_sort_values, x_sort_indices = torch.sort(detached_index, dim=-1, stable=False)
torch.sort
:对detached_index
沿dim=-1
(即n
维度)进行排序。- 若
detached_index=[[2,0,1,0]]
那么detached_index
排序后的值是[[0, 0, 1, 2]]
(即x_sort_values
)。 x_sort_indices
是[[1, 3, 2, 0]]
,表示:- 排序后的第0个元素来自原始位置1(值是0),
- 第1个元素来自原始位置3(值是0),
- 第2个元素来自原始位置2(值是1),
- 第3个元素来自原始位置0(值是2)。