import numpy as np
import torch
# Build a 3x4 matrix holding 0..11 in row-major order.
a = np.arange(12).reshape(3, 4)
# Row-index array of length 10: the pair [0, 1] repeated 5 times.
index = np.tile([0, 1], 5)
print("a:\n", a, a.shape)
print("index:\n", index, index.shape)
a:
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]] (3, 4)
index:
[0 1 0 1 0 1 0 1 0 1] (10,)
# Integer-array ("fancy") indexing: each entry of `index` picks a whole row
# of `a`, so the result has one row per index entry -> shape (10, 4).
print(a[index], a[index].shape)
[[0 1 2 3]
[4 5 6 7]
[0 1 2 3]
[4 5 6 7]
[0 1 2 3]
[4 5 6 7]
[0 1 2 3]
[4 5 6 7]
[0 1 2 3]
[4 5 6 7]] (10, 4)
PyTorch 同样支持这种特性（用整数索引张量按行选取）：
# Same 3x4 matrix of 0..11, as a float32 tensor (the FloatTensor shown in the
# output below).
a = torch.arange(12, dtype=torch.float32).view(3, 4)
# Index tensor must be integer (int64): the pair [0, 1] repeated 5 times.
index = torch.tensor([0, 1]).repeat(5)
print("a:\n", a)
print("index:\n", index)
a:
0 1 2 3
4 5 6 7
8 9 10 11
[torch.FloatTensor of size 3x4]
index:
0
1
0
1
0
1
0
1
0
1
[torch.LongTensor of size 10]
# PyTorch applies the same integer-tensor row indexing; .size() returns a
# torch.Size, here torch.Size([10, 4]).
print(a[index], a[index].size())
0 1 2 3
4 5 6 7
0 1 2 3
4 5 6 7
0 1 2 3
4 5 6 7
0 1 2 3
4 5 6 7
0 1 2 3
4 5 6 7
[torch.FloatTensor of size 10x4]
torch.Size([10, 4])