给定一个向量a,输出其中值相同的元素的索引号
比如给定[1,1,2,3,4,5,5,5,5],其中第0,1个元素的值都是1,要输出[0,1],第5,6,7,8个元素的值都是5,要输出[5,6,7,8],如果没有相同的,就输出元素自身的索引号。
def getIdx(a):
co = a.unsqueeze(0)-a.unsqueeze(1)
uniquer = co.unique(dim=0)
out = []
for r in uniquer:
cover = torch.arange(a.size(0))
mask = r==0
idx = cover[mask]
out.append(idx)
return out
测试:
import torch
a = torch.Tensor([1,1,2,3,4,5,5,5,5])
idxs=getIdx(a)
#output: [tensor([5, 6, 7, 8]), tensor([4]), tensor([3]), tensor([2]), tensor([0, 1])]