1.函数的区别
torch.gt()、torch.ge()、torch.lt()、torch.le()、torch,eq()、torch.ne()分别对应“>”、“≥”、“<”、“≤”、“=”、“≠”
函数名 | 功能 |
torch.gt() | 大于 |
torch.ge() | 大于或等于 |
torch.lt() | 小于 |
torch.le() | 小于或等于 |
torch.eq() | 等于 |
torch.ne() | 不等于 |
2.函数介绍
这里只以torch.gt()函数举例,其他函数用法与torch.gt()函数类似。
Pytorch官方介绍:
torch.gt()函数原型:
torch.gt(input, other, *, out=None) → Tensor
返回值为input与other大小比较的结果。若input大于other返回True,否则返回False。如果input与other的shape相同,那么则在对应位置进行比较;如果input与other的shape不同,则会用到广播机制。
当input与other的shape相同时,直接对应位置的值进行比较。
# input与other的shape相同时,对应位置进行大小比较
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([[2, 2, 2], [2, 2, 2]])
print(other)
result = torch.gt(input, other)
print(result)
output:
tensor([[1, 1, 1],
[3, 3, 3]])
tensor([[2, 2, 2],
[2, 2, 2]])
tensor([[False, False, False],
[ True, True, True]])
当input与other的shape不同时,使用广播机制。在下面的例子中input的shape是[2, 3],other的shape是[3]。可以看到input有两个维度,other只有一个维度,此时会进行广播机制操作。使用other与input的每行进行比较,返回比较值。
# input与other的shape不同时,会采用广播机制
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([2, 2, 2])
print(other)
result = torch.gt(input, other)
print(result)
output:
tensor([[1, 1, 1],
[3, 3, 3]])
tensor([2, 2, 2])
tensor([[False, False, False],
[ True, True, True]])
torch.gt()函数还可以被Tensor类型的对象直接调用,即torch.gt(input, other)等价于input.gt(other)。
# torch.gt()函数的第二种用法
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([[2, 2, 2]])
print(other)
result = input.gt(other)
print(result)
output:
tensor([[1, 1, 1],
[3, 3, 3]])
tensor([2, 2, 2])
tensor([[False, False, False],
[ True, True, True]])
torch.gt()函数中的第三个参数是个可选参数。官方给的解释是可以传入一个数字或者一个与input相同shape的Tensor。当使用到out参数后,返回的值“True”替换成“1”,“False”替换成“0”。
# 使用out参数
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([2, 2, 2])
print(other)
out = torch.zeros_like(input)
result = torch.gt(input, other, out=out)
print(result)
output:
tensor([[1, 1, 1],
[3, 3, 3]])
tensor([2, 2, 2])
tensor([[0, 0, 0],
[1, 1, 1]])
在实际测试时,给out传入一个数字,程序会抛出异常,如下图所示。提示input,other,out这三个参数都应该是Tensor格式。因此给out传入一个与input相同shape的Tensor是较为恰当的做法。
3.官方文档
gt:https://pytorch.org/docs/stable/generated/torch.gt.html?highlight=torch+gt#torch.gt
ge:https://pytorch.org/docs/stable/generated/torch.ge.html?highlight=torch+ge#torch.ge
lt:https://pytorch.org/docs/stable/generated/torch.lt.html?highlight=torch+lt#torch.lt
le:https://pytorch.org/docs/stable/generated/torch.le.html?highlight=torch+le#torch.le
eq:https://pytorch.org/docs/stable/generated/torch.eq.html?highlight=torch+eq#torch.eq
ne:https://pytorch.org/docs/stable/generated/torch.ne.html?highlight=torch+ne#torch.ne