pytorch, numpy利用另一个数组的索引进行重排序,如torch.gather.......

15 篇文章 0 订阅
3 篇文章 0 订阅

构造不同维度上的索引进行重组

直接构造不同维度索引的方法在pytorch和numpy是类似的,具体如下:

1、numpy
x = np.random.randint(0, 10, [2, 3])
print(x)  
 结果:
     [[1 5 1] 
      [9 6 4]]
indices = np.argsort(x, dim=1) # 在第1维度上进行排序
print(indices)
结果:
array([[0, 2, 1],
       [2, 1, 0]])

first_indices = np.arange(2).reshape(2, -1)  # 构造第0维度上的索引
print(first_indices)
结果:
array([[0],
       [1]])
或者使用first_indices = np.arange(indices.shape[0)[:, None], indices]效果一样
result = x[first_indices, indices]
print(result)
结果:
array([[1, 1, 5],
         [4, 6, 9]])
2、pytorch
y = torch.from_numpy(x)
indices = torch.argsort(y, dim=1)
print(indices)
结果:
array([[0, 2, 1],
       [2, 1, 0]])
       
first_indices = torch.arange(indices.shape[0])[:, None]
print(first_indices)
结果:
array([[0],
       [1]])     
result = y[first_indices, indices]
print(result)
结果:
array([[1, 1, 5],
       [4, 6, 9]])

pytorch另外一种方法:torch.gather()

result = y.gather(dim=1, index = indices)
print(result)
结果:
array([[1, 1, 5],
       [4, 6, 9]])

再举一个例子

import torch
torch.manual_seed(1)

logits = torch.rand(3, 10)
label_ids = torch.rand(3, 10)
print(logits)
print(label_ids)
sorted_logits, indices = logits.sort(dim=1)
new_label_ids = label_ids.gather(dim=1, index=indices)
print(indices)
print(new_label_ids)

结果:

tensor([[0.7576, 0.2793, 0.4031, 0.7347, 0.0293, 0.7999, 0.3971, 0.7544, 0.5695,
         0.4388],
        [0.6387, 0.5247, 0.6826, 0.3051, 0.4635, 0.4550, 0.5725, 0.4980, 0.9371,
         0.6556],
        [0.3138, 0.1980, 0.4162, 0.2843, 0.3398, 0.5239, 0.7981, 0.7718, 0.0112,
         0.8100]])
tensor([[0.6397, 0.9743, 0.8300, 0.0444, 0.0246, 0.2588, 0.9391, 0.4167, 0.7140,
         0.2676],
        [0.9906, 0.2885, 0.8750, 0.5059, 0.2366, 0.7570, 0.2346, 0.6471, 0.3556,
         0.4452],
        [0.0193, 0.2616, 0.7713, 0.3785, 0.9980, 0.9008, 0.4766, 0.1663, 0.8045,
         0.6552]])
tensor([[4, 1, 6, 2, 9, 8, 3, 7, 0, 5],
        [3, 5, 4, 7, 1, 6, 0, 9, 2, 8],
        [8, 1, 3, 0, 4, 2, 5, 7, 6, 9]])
tensor([[0.0246, 0.9743, 0.9391, 0.8300, 0.2676, 0.7140, 0.0444, 0.4167, 0.6397,
         0.2588],
        [0.5059, 0.7570, 0.2366, 0.6471, 0.2885, 0.2346, 0.9906, 0.4452, 0.8750,
         0.3556],
        [0.8045, 0.2616, 0.3785, 0.0193, 0.9980, 0.7713, 0.9008, 0.1663, 0.4766,
         0.6552]])

从上述结果可以看到,label_ids按照logits的排序进行的重排序。这个在机器学习训练模型对数据进行shuffle,或者对多个多维数据进行排序时经常用到。

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值