pytorch tensor矩阵索引矩阵
正确性有待考究。。。
含有 : 索引
import torch as t
neg_mask_ints = t.arange(36).view(3, 3, 4)
target = t.randint(0, 4, (3,))
print('target: ', target)
print(neg_mask_ints)
b = neg_mask_ints[:, :, target]
print(b)
print(b.shape)
输出
target: tensor([2, 3, 1])
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]],
[[24, 25, 26, 27],
[28, 29, 30, 31],
[32, 33, 34, 35]]])
tensor([[[ 2, 3, 1],
[ 6, 7, 5],
[10, 11, 9]],
[[14, 15, 13],
[18, 19, 17],
[22, 23, 21]],
[[26, 27, 25],
[30, 31, 29],
[34, 35, 33]]])
torch.Size([3, 3, 3])
意思是说,对所有矩阵的每一行都执行列的索引抽取
含有tensor的索引
impo