一、问题描述
假设有三个张量a、b和c,其中张量a的shape为(8, 2),张量b的shape为(8, 5),张量c的shape为(1, 2)。张量a和张量b是对应的关系,并且张量a和张量b每行元素都唯一、不会重复。而我需要通过检索的方式来获得张量c在张量a中的索引(索引值大于等于0),如果没有那么就返回一个异常值(-1)。例如
a = [ [1,2],
[2,3],
[3,4],
[3,8],
[4,5],
[4,1],
[5,6],
[5,5] ]
b = [ [0, 0, 0, 0, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 1, 0],
[0, 0, 0, 1, 1],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 1],
[0, 0, 1, 1, 0],
[0, 0, 1, 1, 1] ]
c = [ [4,1] ]
我想要得到张量c在张量a中的索引值5,进而得到张量b中的[0, 0, 1, 0, 1]
二、解决方法
代码(主要是采取遍历的思想,因为实在是找不到太好的api)
import torch
# input_: a 2d tensor
# query_: a 2d tensor
# function: get the index of query_ in input_
def index_tensor_by_tensor(input_, query_):
# default index of result
idx = -1
# index range of tensor2d_a
n_a = input_.shape[0]
# traverse tensor2d_a
for i in range(n_a):
if input_[i][0] == query_[0][0] and input_[i][1] == query_[0][1]:
# find the query tensor
idx = i
break
return idx
# main function
if __name__ == '__main__':
# input
a = torch.tensor([[1, 2],
[2, 3],
[3, 4],
[3, 8],
[4, 5],
[4, 1],
[5, 6],
[5, 5]])
b = torch.tensor([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 1, 0],
[0, 0, 0, 1, 1],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 1],
[0, 0, 1, 1, 0],
[0, 0, 1, 1, 1]])
c_1 = torch.tensor([[4, 1]]) # it is exist in a
c_2 = torch.tensor([[7, 1]]) # it is not exist in a
# query
res_i_1 = index_tensor_by_tensor(a, c_1)
res_i_2 = index_tensor_by_tensor(a, c_2)
# output
if res_i_1 != -1:
print(res_i_1, b.select(0, res_i_1))
else:
print(res_i_1, None)
if res_i_2 != -1:
print(res_i_2, b.select(0, res_i_2))
else:
print(res_i_2, None)
效果
5 tensor([0, 0, 1, 0, 1])
-1 None