import torch
# 数据
a = [[1, 1, 2.5],
[2.8, 7.5, 5.4],
[2.5, 9.2, 3.8],
[2.5, 3.2, 3.5]]
c = torch.tensor(a)
'''
torch.sort(input, dim=-1, descending=False)
input:输入数据
dim:在哪个维度排序
descending:升序还是降序。Ture为降序,默认为False
'''
X = torch.sort(c,1,descending=False) # 返回两个值,第一个为排序后数据,第二个为索引
print(X)
# 输出
torch.return_types.sort(
values=tensor([[1.0000, 1.0000, 2.5000],
[2.8000, 5.4000, 7.5000],
[2.5000, 3.8000, 9.2000],
[2.5000, 3.2000, 3.5000]]),
indices=tensor([[0, 1, 2],
[0, 2, 1],
[0, 2, 1],
[0, 1, 2]]))
torch.sort()排序函数的参数
最新推荐文章于 2024-08-29 18:24:36 发布
该博客介绍了如何在PyTorch中使用`torch.sort()`函数对数据进行排序。通过示例展示了如何在指定维度上进行升序或降序排序,并打印了排序后的数据及其对应的索引。此内容对于理解和应用PyTorch中的数据处理非常有用。
摘要由CSDN通过智能技术生成