torch.unique()的功能类似于数学中的集合,就是挑出tensor中的独立不重复元素。
这个方法的参数在官方解释文档中有这么几个:torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)
input: 待处理的tensor
sorted:是否对返回的无重复张量按照数值进行排列,默认是生序排列的
return_inverse: 是否返回原始tensor中的每个元素在这个无重复张量中的索引
return_counts: 统计原始张量中每个独立元素的个数
dim: 值沿着哪个维度进行unique的处理,这个我试验后没有搞懂怎样的机理。如果处理的张量都是一维的,那么这个不需要理会。
下面分别对这些不同的参数进行实验讲解分析。
import torch
x = torch.tensor([4,0,1,2,1,2,3])#生成一个tensor,作为实验输入
print(x)
out = torch.unique(x) #所有参数都设置为默认的
print(out)#将处理结果打印出来
#结果如下:
#tensor([0, 1, 2, 3, 4]) #将x中的不重复元素挑了出来,并且默认为生序排列
out = torch.unique(x,sorted=False)#将默认的生序排列改为False
print(out)
#输出结果如下:
#tensor([3, 2, 1, 0, 4]) #将x中的独立元素找了出来,就按照原始顺序输出
out = torch.unique(x,return_inverse=True)#将原始数据中的每个元素在新生成的独立元素张量中的索引输出
print(out)
#输出结果如下:
#(tensor([0, 1, 2, 3, 4]), tensor([4, 0, 1, 2, 1, 2, 3])) #第一个张量是排序后输出的独立张量,第二个结果对应着原始数据中的每个元素在新的独立无重复张量中的索引,比如x[0]=4,在新的张量中的索引为4, x[1]=0,在新的张量中的索引为0,x[6]=3,在新的张量中的索引为3
out = torch.unique(x,return_counts=True) #返回每个独立元素的个数
print(out)
#输出结果如下
#(tensor([0, 1, 2, 3, 4]), tensor([1, 2, 2, 1, 1])) #0这个元素在原始数据中的数量为1,1这个元素在原始数据中的数量为2
转载自:https://blog.csdn.net/t20134297/article/details/108235355