tensor1[tensor2]
刚看到这个结构有点懵,不知道它是具体怎么工作的
example.py
a = torch.arange(16)
b = torch.tensor([2,2,0,1,0,0,1,0,2,1,0,0,1,0,0,0],dtype=torch.uint8)
print(a)
print(b)
print(a[b])
index_list = [[4,3,2,1,0]]
c = torch.LongTensor(index_list)
# print(a)
print(a[c])
print(a.shape,c.shape,a[c].shape)
d = []
for i,index in enumerate(index_list):
d.append(a[index])
print(d)
'''output
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
tensor([2, 2, 0, 1, 0, 0, 1, 0, 2, 1, 0, 0, 1, 0, 0, 0], dtype=torch.uint8)
tensor([ 0, 1, 3, 6, 8, 9, 12])
tensor([[4, 3, 2, 1, 0]])
torch.Size([16]) torch.Size([1, 5]) torch.Size([1, 5])
[tensor([4, 3, 2, 1, 0])]
'''
索引为torch.uint8类型
可以看到在tensor2
为bool/uint8
类型时,tensor2 更像是一个mask,将原有tensor进行筛选一遍,取出tensor2 对应位置不为0的元素
索引为torch.long类型
这个时候就比较麻烦了,tensor2
中存的更像是tensor1
中的位置id, 这个时候a[b].shape == b.shape
相当于在 tensor2 中将所有的元素替换成tensor1中指定位置的元素,写了一个替代脚本:
a = torch.arange(16)
index_list = [[4,3,2,1,0]]
c = torch.LongTensor(index_list)
print(a[c])
d = []
for i,index in enumerate(index_list):
d.append(a[index])
print(d)
# a[c] == d
## 多维的tensor
a = torch.arange(12).view(4,3)
print(a[c])
print(a.shape,c.shape,a[c].shape)
d = []
for i,index in enumerate(index_list):
d.append(a[index])
print(d)
'''
tensor([[[ 6, 7, 8],
[ 9, 10, 11],
[ 6, 7, 8],
[ 3, 4, 5],
[ 0, 1, 2]]])
torch.Size([4, 3]) torch.Size([1, 5]) torch.Size([1, 5, 3])
[tensor([[ 6, 7, 8],
[ 9, 10, 11],
[ 6, 7, 8],
[ 3, 4, 5],
[ 0, 1, 2]])]
'''