1.import torch # 假设 tensor 是一个 (n,) 形状的 Torch 张量 tensor = torch.tensor([1, 2, 3, ..., n]) # 转换为 (n, 1) 形状的张量 tensor_2d = tensor.unsqueeze(dim=1)
2.
# 假设我们有一个包含重复值的张量 x = torch.tensor([1, 2, 3, 2, 1, 4, 5, 4]) # 只接收排序后的唯一元素 unique_values = torch.unique(x)
3.torch.sort()
是 PyTorch 中的一个函数,用于对张量进行排序。它返回两个值:
-
sorted_tensor:这是排序后的张量,其中元素按照升序排列。
-
indices(可选):这是一个 LongTensor 类型的张量,包含原始输入张量中对应元素在排序后张量中的索引。这意味着,如果你用
indices
索引原始张量,你会得到一个与sorted_tensor
相同顺序的张量。