Tensor indexing
Let's start with a small example:
import torch as t
neg_mask_ints = t.arange(36).view(3, 3, 4)
target = t.randint(0, 4, (3,))
print('target: ', target)
print(neg_mask_ints)
b = neg_mask_ints[:, :, target]
print(b)
print(b.shape)
Output:
target: tensor([2, 3, 1])
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[24, 25, 26, 27],
         [28, 29, 30, 31],
         [32, 33, 34, 35]]])
tensor([[[ 2,  3,  1],
         [ 6,  7,  5],
         [10, 11,  9]],

        [[14, 15, 13],
         [18, 19, 17],
         [22, 23, 21]],

        [[26, 27, 25],
         [30, 31, 29],
         [34, 35, 33]]])
torch.Size([3, 3, 3])
In other words, the same set of column indices is extracted from every row of every sub-matrix along the last dimension.
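That relationship can be checked directly. A minimal sketch, using the same shapes as above but with a fixed target so the check is reproducible:

```python
import torch as t

# Same setup as above, with target fixed instead of random.
x = t.arange(36).view(3, 3, 4)
target = t.tensor([2, 3, 1])
b = x[:, :, target]

# Indexing the last dimension with a 1-D tensor gathers those columns
# from every row: b[i, j, k] == x[i, j, target[k]].
for i in range(3):
    for j in range(3):
        for k in range(3):
            assert b[i, j, k] == x[i, j, target[k]]

print(b.shape)  # torch.Size([3, 3, 3])
```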
Fancy indexing with tensors
import torch
neg_mask_ints = torch.randn((3, 3, 4))
batch_ints = list(range(3))
a = torch.randint(1, 3, (3, 2))
target = torch.randint(1, 4, (3, 2, 1))
print('a:', a[:, 0])
print('target: ', target)
# neg_mask_ints[batch_ints, a[:, 0], target] = 0
print(neg_mask_ints)
print('--------------------------------------')
print(neg_mask_ints[batch_ints, a[:, 0], target] )
Output:
a: tensor([1, 2, 2])
target:  tensor([[[3],
         [2]],

        [[3],
         [2]],

        [[2],
         [1]]])
tensor([[[ 0.4134,  1.3359, -0.9656, -1.2232],
         [ 0.7877,  0.8549, -0.8297,  0.5871],
         [-1.6027,  0.7224,  0.9049,  0.3045]],

        [[-0.4444,  0.9694,  1.1592, -0.8034],
         [ 0.1240, -0.8224, -0.5136,  1.4698],
         [-0.5549,  1.6872, -0.7554,  0.3377]],

        [[-1.9502, -1.4392,  0.6415,  0.7268],
         [ 0.1807,  0.2582, -1.3487, -1.4340],
         [ 0.1910, -0.1522,  0.1951,  0.4255]]])
--------------------------------------
tensor([[[ 0.5871,  0.3377,  0.4255],
         [-0.8297, -0.7554,  0.1951]],

        [[ 0.5871,  0.3377,  0.4255],
         [-0.8297, -0.7554,  0.1951]],

        [[-0.8297, -0.7554,  0.1951],
         [ 0.8549,  1.6872, -0.1522]]])
For the reason, see: "Using a Tensor as an Index in PyTorch"
https://blog.csdn.net/xpy870663266/article/details/101597144
The reason is broadcasting. batch_ints (shape (3,)) and a[:, 0] (shape (3,)) pair up element-wise into the (batch, row) coordinates (0, 1), (1, 2), (2, 2); target (shape (3, 2, 1)) then broadcasts against them, so each of its elements is combined with all three (batch, row) pairs to form full (batch, row, column) coordinates.
In general, when a multi-dimensional tensor A is indexed with one index tensor per dimension, as in A[a, b, c], all the index tensors must be broadcastable to a common shape. The result takes that broadcast shape, and each output element is read from the coordinates given by the broadcast index tensors. Here the common shape of (3,), (3,), and (3, 2, 1) is (3, 2, 3), which is exactly the output shape above.
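The fancy-indexing result above can be reproduced by doing the broadcasting by hand. A minimal sketch with the same shapes as the example (torch.broadcast_tensors makes the implicit broadcast explicit; the element-wise loop is only for illustration):

```python
import torch

torch.manual_seed(0)  # any seed; only shapes matter for this check

x = torch.randn(3, 3, 4)
batch = torch.arange(3)                  # shape (3,)
rows = torch.randint(1, 3, (3,))         # shape (3,)
cols = torch.randint(1, 4, (3, 2, 1))    # shape (3, 2, 1)

# Fancy indexing broadcasts the three index tensors to a common shape.
fancy = x[batch, rows, cols]

# Reproduce it: broadcast explicitly, then gather element-wise. Every
# output element comes from one (batch, row, col) coordinate triple.
b0, b1, b2 = torch.broadcast_tensors(batch, rows, cols)  # each (3, 2, 3)
manual = torch.empty(3, 2, 3)
for i in range(3):
    for j in range(2):
        for k in range(3):
            manual[i, j, k] = x[b0[i, j, k], b1[i, j, k], b2[i, j, k]]

assert fancy.shape == (3, 2, 3)
assert torch.equal(fancy, manual)
```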
NumPy indexing
import numpy as np
pr_probs_cl = np.random.randint(1, 11, (3, 2, 5))
batch_idxs = np.arange(pr_probs_cl.shape[0])
gt_classes = np.random.randint(0, 5, (3, 2, 1))
a = pr_probs_cl[batch_idxs, :, gt_classes]
print(gt_classes)
print('------------------------------')
print(pr_probs_cl)
print('------------------------------')
print(a)
print(a.shape)
Output:
[[[0]
  [1]]

 [[3]
  [2]]

 [[4]
  [2]]]
------------------------------
[[[ 1  6  7  4  9]
  [10  1  2  1  6]]

 [[ 9 10  1  8  2]
  [ 6  4  9  6  1]]

 [[ 9  5  6  2  8]
  [ 2  9  9  6  9]]]
------------------------------
[[[[ 1 10]
   [ 9  6]
   [ 9  2]]

  [[ 6  1]
   [10  4]
   [ 5  9]]]


 [[[ 4  1]
   [ 8  6]
   [ 2  6]]

  [[ 7  2]
   [ 1  9]
   [ 6  9]]]


 [[[ 9  6]
   [ 2  1]
   [ 8  9]]

  [[ 7  2]
   [ 1  9]
   [ 6  9]]]]
(3, 2, 3, 2)
So NumPy's fancy indexing broadcasts index arrays the same way PyTorch's does. The extra wrinkle in this example: because the two advanced indices batch_idxs and gt_classes are separated by a slice (:), NumPy moves the broadcast dimensions (3, 2, 3) to the front of the result and appends the sliced dimension, giving shape (3, 2, 3, 2).
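NumPy's placement rule for advanced indices can be seen with small hand-picked indices. A minimal sketch (my reading of NumPy's advanced-indexing semantics): adjacent advanced indices keep the broadcast dimensions in place, while indices separated by a slice push them to the front.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
i = np.array([0, 1])        # shape (2,)
j = np.array([[0], [2]])    # shape (2, 1); i and j broadcast to (2, 2)

# Adjacent advanced indices: the broadcast shape (2, 2) replaces the two
# indexed dimensions in place, after the leading sliced dimension.
adj = x[:, i, j]
assert adj.shape == (2, 2, 2)

# Advanced indices separated by a slice: the broadcast shape (2, 2) moves
# to the front, followed by the sliced middle dimension (3,).
sep = x[i, :, j]
assert sep.shape == (2, 2, 3)
```

This is exactly why the (3, 2, 5) example above produced shape (3, 2, 3, 2) rather than (3, 2, 3) followed by nothing: the broadcast index shape (3, 2, 3) came first, then the sliced dimension of size 2.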