文章目录
- 常用的比较操作
-
- 1.torch.allclose()
- 2.torch.argsort()
- 3.torch.eq()
- 4.torch.equal()
- 5.torch.greater_equal()
- 6.torch.gt()
- 7.torch.isclose()
- 8.torch.isfinite()
- 9.torch.isif()
- 10.torch.isposinf()
- 11.torch.isneginf()
- 12.torch.isnan()
- 13.torch.kthvalue()
- 14.torch.less_equal()
- 15.torch.maximum()
- 16.torch.fmax()
- 17.torch.ne()
- 18.torch.sort()
- 19.torch.topk()
常用的比较操作
1.torch.allclose()
torch.allclose() 是 PyTorch 中用于比较两个张量是否在给定的容差范围内近似相等的函数。它可以用于比较浮点数张量之间的相等性。
torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
"""
input:第一个输入张量。
other:第二个输入张量。
rtol:相对容差(relative tolerance),默认为 1e-05。
atol:绝对容差(absolute tolerance),默认为 1e-08。
equal_nan:一个布尔值,指示是否将 NaN 视为相等,默认为 False。
"""
import torch
# 比较两个张量是否近似相等
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0002, 3.0003])
is_close = torch.allclose(x, y, rtol=1e-03, atol=1e-05)
print(is_close)# True
2.torch.argsort()
torch.argsort() 是 PyTorch 中用于对张量进行排序并返回排序后的索引的函数。它返回一个新的张量,其中每个元素表示原始张量中对应位置的元素在排序后的顺序中的索引值。
torch.argsort(input, dim=-1, descending=False, *, out=None)
"""
input:输入张量。
dim:指定排序的维度,默认为 -1,表示最后一个维度。
descending:一个布尔值,指示是否按降序排序,默认为 False。
out:可选参数,用于指定输出张量的位置。
"""
import torch
# 对张量进行排序并返回索引
x = torch.tensor([3, 1, 4, 2])
sorted_indices = torch.argsort(x)
print(sorted_indices)
# tensor([1, 3, 0, 2])
3.torch.eq()
torch.eq() 是 PyTorch 中用于执行元素级别相等性比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素相等,元素为 False 表示对应位置的元素不相等。
torch.eq(input, other, out=None)
"""
input:第一个输入张量。
other:第二个输入张量。
out:可选参数,用于指定输出张量的位置。
"""
import torch
# 执行元素级别的相等性比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2, 4])
result = torch.eq(x, y)
print(result)# tensor([ True, True, False])
4.torch.equal()
torch.equal() 是 PyTorch 中用于检查两个张量是否在元素级别上完全相等的函数。它返回一个布尔值,指示两个张量是否具有相同的形状和相同的元素值。
torch.equal(input, other)
"""
input:第一个输入张量。
other:第二个输入张量。
"""
import torch
# 检查两个张量是否完全相等
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2, 3])
is_equal = torch.equal(x, y)
print(is_equal)# True