1 torch.sort()

torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)

1.1 作用

根据给定的维度对输入张量进行升值或降值排序。

1.2 参数

input: 需要是一个torch.Tensor类型的张量。

dim: 给定一个张量的维度(int型),按照这个维度上的数值进行排序。如果不指定,默认按照张量的最后一个维度进行排序。

descending:传入一个布尔类型的数据(Ture、False),True代表降值排序,False代表升值排序。如果不指定,默认升值排序。

stable:传入一个布尔类型的数据(Ture、False),当一个张量中存在多个相同数字时,例如[2, 2, 1, 1],传入True不会打乱同一个数字的先后顺序(第一个1会排在第一个,第二个1会排在第二个)。如果不指定,默认False。

out:(Tensor, LongTensor) 的输出元组,可以选择用作输出缓冲区。如果不指定,默认None。

1.3 举例

先是只传入张量,其他参数均为默认:

import torch

tensor_a = torch.tensor([[2, 1],
                         [3, 4],
                         [6, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)

#---------输出---------#
tensor([[1, 2],
        [3, 4],
        [5, 6]]) 
 tensor([[1, 0],
        [0, 1],
        [1, 0]])
#----------------------#

 dim = 0 的情况:

import torch

tensor_a = torch.tensor([[6, 1],
                         [1, 4],
                         [2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0)
print(sorted_tensor_a, '\n', indices)

#---------输出---------#
tensor([[1, 1],
        [2, 4],
        [6, 5]]) 
 tensor([[1, 0],
        [2, 1],
        [0, 2]])
#----------------------#

descending = True 的情况:

import tensor

tensor_a = torch.tensor([[6, 1],
                         [1, 4],
                         [2, 5]])
sorted_tensor_a, indices = torch.sort(tensor_a, dim=0, descending=True)
print(sorted_tensor_a, '\n', indices)

#---------输出---------#
tensor([[6, 5],
        [2, 4],
        [1, 1]]) 
 tensor([[0, 2],
        [2, 1],
        [1, 0]])
#----------------------#

stable =  True 的情况:

import torch

tensor_a = torch.tensor([0, 1] * 9)

sorted_tensor_a, indices = torch.sort(tensor_a, stable=True)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) 
 tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17])
#----------------------------------------------------------#

sorted_tensor_a, indices = torch.sort(tensor_a)
print(sorted_tensor_a, '\n', indices)
#---------------------------输出---------------------------#
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) 
 tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16,  1,  3,  5,  7,  9, 11, 13, 15, 17])
#----------------------------------------------------------#

可以看到,在我的运行结果中,stable无论是True还是False,好像结果都是一样的,但是以下是官方教程中的例子:

 同样的函数,同样的输入,为什么我的和官方的输出不一样我也不是很清楚。。。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值