1.用法介绍
pytorch中的torch.sort()是按照指定的维度对tensor张量的元素进行排序。
torch.sort(input, dim=-1, descending=False, stable=False, * ,out =None)
输入:
- input(tensor): 表示输入的张量
- dim(int, optional):表示张量元素排序的维度
- descending(bool, optional):0表示升序,1表示降序
- stable(bool, optional):保持张量中相等元素的顺序
输出:
- 第一个tensor表示对应维度张量值排序;第二个tensor表示对应维度张量值得索引序列
2. 代码实例
torch.sort()的代码实例如下所示:
import torch
torch.manual_seed(2)
Tensor = torch.rand((3,3))
print(Tensor)
print(torch.sort(Tensor,0)) # print(Tensor.sort(0))
print(torch.sort(Tensor,1)) # print(Tensor.sort(1))
运行的实验结果如下所示:
tensor([[0.6147, 0.3810, 0.6371],
[0.4745, 0.7136, 0.6190],
[0.4425, 0.0958, 0.6142]])
torch.return_types.sort(
values=tensor([[0.4425, 0.0958, 0.6142],
[0.4745, 0.3810, 0.6190],
[0.6147, 0.7136, 0.6371]]),
indices=tensor([[2, 2, 2],
[1, 0, 1],
[0, 1, 0]]))
torch.return_types.sort(
values=tensor([[0.3810, 0.6147, 0.6371],
[0.4745, 0.6190, 0.7136],
[0.0958, 0.4425, 0.6142]]),
indices=tensor([[1, 0, 2],
[0, 2, 1],
[1, 0, 2]]))