tensor和numpy的花式索引

tensor索引

先上一个小例子

import torch as t

neg_mask_ints = t.arange(36).view(3, 3, 4)
target = t.randint(0, 4, (3,))
print('target: ', target)
print(neg_mask_ints)
b = neg_mask_ints[:, :, target]
print(b)
print(b.shape)

输出

target:  tensor([2, 3, 1])
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35]]])
tensor([[[ 2,  3,  1],
         [ 6,  7,  5],
         [10, 11,  9]],

        [[14, 15, 13],
         [18, 19, 17],
         [22, 23, 21]],

        [[26, 27, 25],
         [30, 31, 29],
         [34, 35, 33]]])
torch.Size([3, 3, 3])

意思是说,对所有矩阵的每一行都执行列的索引抽取

tensor花式索引

import torch 

neg_mask_ints = torch.randn((3, 3, 4))
batch_ints = list(range(3))
a = torch.randint(1, 3, (3, 2))
target = torch.randint(1, 4, (3, 2, 1))
print('a:', a[:, 0])
print('target: ', target)
# neg_mask_ints[batch_ints, a[:, 0], target] = 0
print(neg_mask_ints)
print('--------------------------------------')
print(neg_mask_ints[batch_ints, a[:, 0], target] )

输出

a: tensor([1, 2, 2])
target:  tensor([[[3],
         [2]],

        [[3],
         [2]],

        [[2],
         [1]]])
tensor([[[ 0.4134,  1.3359, -0.9656, -1.2232],
         [ 0.7877,  0.8549, -0.8297,  0.5871],
         [-1.6027,  0.7224,  0.9049,  0.3045]],

        [[-0.4444,  0.9694,  1.1592, -0.8034],
         [ 0.1240, -0.8224, -0.5136,  1.4698],
         [-0.5549,  1.6872, -0.7554,  0.3377]],

        [[-1.9502, -1.4392,  0.6415,  0.7268],
         [ 0.1807,  0.2582, -1.3487, -1.4340],
         [ 0.1910, -0.1522,  0.1951,  0.4255]]])
--------------------------------------
tensor([[[ 0.5871,  0.3377,  0.4255],
         [-0.8297, -0.7554,  0.1951]],

        [[ 0.5871,  0.3377,  0.4255],
         [-0.8297, -0.7554,  0.1951]],

        [[-0.8297, -0.7554,  0.1951],
         [ 0.8549,  1.6872, -0.1522]]])

原因参考:PyTorch中使用Tensor作为索引

https://blog.csdn.net/xpy870663266/article/details/101597144

因为:x,y轴坐标先相乘,组成(0, 1), (1, 2), ( 2 ,2),再z的每个元素乘上前面的三个坐标。
也就是说,当对多维(大于1维)tensor(如A)进行索引,若A的每一维都指定一个tensor,如A(a, b, c, d),则a, b, c必须维度相同且维度小于A前三维的最小维度。d可以为1维tensor,也可以为多维,但此时d最后一维的维度必须为1。
运算时,a, b, c先按元素相乘,得到一组坐标。当d为一维时,前面的一维坐标继续乘d得到一组坐标,正常索引;当d为多维时,d的每个元素依次和这组坐标组成新坐标进行索引。

在这里插入图片描述

Numpy索引

import numpy as np
pr_probs_cl = np.random.randint(1, 11, (3, 2, 5))
batch_idxs = np.arange(pr_probs_cl.shape[0])
gt_classes = np.random.randint(0, 5, (3, 2, 1))
a = pr_probs_cl[batch_idxs, :, gt_classes]print(a)
print(gt_classes)
print('------------------------------')
print(pr_probs_cl)
print('------------------------------')
print(a)
print(a.shape)

输出

[[[0]
  [1]]

 [[3]
  [2]]

 [[4]
  [2]]]
------------------------------
[[[ 1  6  7  4  9]
  [10  1  2  1  6]]

 [[ 9 10  1  8  2]
  [ 6  4  9  6  1]]

 [[ 9  5  6  2  8]
  [ 2  9  9  6  9]]]
------------------------------
[[[[ 1 10]
   [ 9  6]
   [ 9  2]]

  [[ 6  1]
   [10  4]
   [ 5  9]]]


 [[[ 4  1]
   [ 8  6]
   [ 2  6]]

  [[ 7  2]
   [ 1  9]
   [ 6  9]]]


 [[[ 9  6]
   [ 2  1]
   [ 8  9]]

  [[ 7  2]
   [ 1  9]
   [ 6  9]]]]
(3, 2, 3, 2)

(难道)numpy的花式索引和tensor一样(?)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值