torch.gather()和torch.sort

torch.gather()

def semantic_neighbor(x, index):
'''
假设x.shape=[B,L,C]=[2,3,4]   index.shape=[B,L]=[2,3]
x = torch.tensor([
    [[1, 2, 3, 4],    # 样本1的3个元素,每个元素4维特征
     [5, 6, 7, 8],
     [9, 10, 11, 12]],
    
    [[13, 14, 15, 16], # 样本2的3个元素
     [17, 18, 19, 20],
     [21, 22, 23, 24]]
])

# 索引张量 index (B=2, L=3)
index = torch.tensor([
    [1, 0, 1],  # 样本1的重组索引
    [2, 1, 0]   # 样本2的重组索引
])

'''
    dim = index.dim()#dim=2
    assert x.shape[:dim] == index.shape, "x ({:}) and index ({:}) shape incompatible".format(x.shape, index.shape)
    for _ in range(x.dim() - index.dim()):
        index = index.unsqueeze(-1)
        '''
        x.index=[2,3]
        index = torch.tensor([
    	[[1],[0], [1]], 
    	[[2], [1], [0]]  ])
        '''
    index = index.expand(x.shape)
          '''
        x.index=[2,3,4]
        index = torch.tensor([
    	[[1,1,1,1],
    	[0,0,0,0], 
    	[1,1,1,1]
    	], 
    	[[2,2,2,2], 
    	[1,1,1,1], 
    	[0,0,0,0]
    	]  ])
        '''
    shuffled_x = torch.gather(x, dim=dim - 1, index=index)
    '''
    tensor([
    [[ 5,  6,  7,  8],  # 来自原始位置1
     [ 1,  2,  3,  4],  # 来自原始位置0
     [ 5,  6,  7,  8]], # 来自原始位置1
     
    [[21, 22, 23, 24],  # 来自原始位置2
     [17, 18, 19, 20],  # 来自原始位置1
     [13, 14, 15, 16]]  # 来自原始位置0
])
    '''
    return shuffled_x


'''
另一个简单的示例:
源张量(3x4矩阵)
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

索引张量(2x3矩阵)
index = torch.tensor([[0, 1, 2],
                      [2, 1, 0]])

沿dim=0(行方向)收集
out = torch.gather(x, dim=0, index=index)


结果:
[[1,  6, 11],  # 取x[0][0], x[1][1], x[2][2]
 [9,  6,  3]]  # 取x[2][0], x[1][1], x[0][2]]
'''

x.sort()
x_sort_values, x_sort_indices = torch.sort(detached_index, dim=-1, stable=False)

  • torch.sort:对 detached_index 沿 dim=-1(即 n 维度)进行排序。
  • detached_index=[[2,0,1,0]]那么detached_index 排序后的值是 [[0, 0, 1, 2]](即 x_sort_values)。
  • x_sort_indices[[1, 3, 2, 0]],表示:
    • 排序后的第0个元素来自原始位置1(值是0),
    • 第1个元素来自原始位置3(值是0),
    • 第2个元素来自原始位置2(值是1),
    • 第3个元素来自原始位置0(值是2)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值