张量的函数索引
在PyTorch中,我们还可以使用torch.index_select()函数,通过指定index来对张量进行索引。
(1)torch.index_select()函数的使用
t1 = torch.arange(1, 11)
#结果为:tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
t1.ndim
#结果为:1
indices = torch.tensor([1, 2])
#结果为:tensor([1, 2])
torch.index_select(t1, 0, indices)
#表示在t1张量中,第一个维度(行)上,查看索引为1,2的张量元素;并返回一个一维张量
#结果为:tensor([2, 3])
注:在index_select函数中,第二个参数实际上代表的是索引的维度。对于t1这个一维向量来说,由于只有一个维度,因此第二个参数取值为0,就代表在第一个维度上进行索引
t2 = torch.arange(12).reshape(4, 3)
#结果为:tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
t2.shape
#结果为:torch.Size([4, 3])
indices = torch.tensor([1, 2])
#结果为:tensor([1, 2])
torch.index_select(t2, 0, indices) #第二个参数为0,表示第一个维度(行),返回一个二维张量
#结果为:tensor([[3, 4, 5],
[6, 7, 8]])
注:dim参数取值为0,代表在shape的第一个维度(行)上索引
torch.index_select(t2, 1, indices)
#结果为:tensor([[ 1, 2],
[ 4, 5],
[ 7, 8],
[10, 11]])
注:dim参数取值为1,代表在shape的第二个维度(列)上索引
以上是本人的浅显见解,还请多多指教。