1.官方解释
- 解释
Returns the indices that sort a tensor along a given dimension in ascending order by value.
返回沿着给定维数按值升序对张量排序的索引。 - 重点
是按照值的顺序排列
2. 举例说明
a = torch.randn(4,4)
a = tensor([[ 0.0785, 1.5267, -0.8521, 0.4065],
[ 0.1598, 0.0788, -0.0745, -1.2700],
[ 1.2208, 1.0722, -0.7064, 1.2564],
[ 0.0669, -0.2318, -0.8229, -0.9280]])
b = torch.argsort(a,dim=1)
b = tensor([[2, 0, 3, 1],
[3, 2, 1, 0],
[2, 1, 0, 3],
[3, 2, 1, 0]])
我们来分析下,b 为什么是这样的。
起初我们感官的认为,当a的第一行值为 [ 0.0785, 1.5267, -0.8521, 0.4065] 的时候,我们排序应该为如下:
按照我们的感觉应该得出 b 为 [1,3,0,2]才行,但是最后输出的结果居然是[2,0,3,1];居然跟我们设想的不一样,那为啥不对呢,主要原因是我们得出来的值是按照序号排列的,而官方文档说的是按照值来排序的。
- 正确的操作:
说明:
第一步是将 a 按照顺序值进行排序得到新的序列:
[-0.8521, 0.0785, 0.4065, 1.5267]
第二步是如果才能通过序列a来得到升序的序列,
我们发现
a[2]=-0.8521,a[0]=0.0785,a[3]=0.4065,a[1]=1.5267
所以 b 返回的是序列值:[2,0,3,1];这样我们就可以通过这个序列[2,0,3,1]直接将数据按照升序进行排列了。