Pytorch中比较大小的相关函数:torch.gt()、torch.ge、torch.lt()、torch.le()、torch.eq()、torch.ne()

1.函数的区别

        torch.gt()、torch.ge()、torch.lt()、torch.le()、torch,eq()、torch.ne()分别对应“>”、“≥”、“<”、“≤”、“=”、“≠”

函数名功能
torch.gt()大于
torch.ge()大于或等于
torch.lt()小于
torch.le()小于或等于
torch.eq()等于
torch.ne()不等于

2.函数介绍

        这里只以torch.gt()函数举例,其他函数用法与torch.gt()函数类似。

        Pytorch官方介绍:

         torch.gt()函数原型:

torch.gt(input, other, *, out=None) → Tensor

        返回值为input与other大小比较的结果。若input大于other返回True,否则返回False。如果input与other的shape相同,那么则在对应位置进行比较;如果input与other的shape不同,则会用到广播机制。

        当input与other的shape相同时,直接对应位置的值进行比较。

# input与other的shape相同时,对应位置进行大小比较
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([[2, 2, 2], [2, 2, 2]])
print(other)
result = torch.gt(input, other)
print(result)

output:
tensor([[1, 1, 1],
        [3, 3, 3]])
tensor([[2, 2, 2],
        [2, 2, 2]])
tensor([[False, False, False],
        [ True,  True,  True]])

         当input与other的shape不同时,使用广播机制。在下面的例子中input的shape是[2, 3],other的shape是[3]。可以看到input有两个维度,other只有一个维度,此时会进行广播机制操作。使用other与input的每行进行比较,返回比较值。

# input与other的shape不同时,会采用广播机制
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([2, 2, 2])
print(other)
result = torch.gt(input, other)
print(result)

output:
tensor([[1, 1, 1],
        [3, 3, 3]])
tensor([2, 2, 2])
tensor([[False, False, False],
        [ True,  True,  True]])

        torch.gt()函数还可以被Tensor类型的对象直接调用,即torch.gt(input, other)等价于input.gt(other)。

# torch.gt()函数的第二种用法
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([[2, 2, 2]])
print(other)
result = input.gt(other)
print(result)

output:
tensor([[1, 1, 1],
        [3, 3, 3]])
tensor([2, 2, 2])
tensor([[False, False, False],
        [ True,  True,  True]])

        torch.gt()函数中的第三个参数是个可选参数。官方给的解释是可以传入一个数字或者一个与input相同shape的Tensor。当使用到out参数后,返回的值“True”替换成“1”,“False”替换成“0”。

# 使用out参数
input = torch.tensor([[1, 1, 1], [3, 3, 3]])
print(input)
other = torch.tensor([2, 2, 2])
print(other)
out = torch.zeros_like(input)
result = torch.gt(input, other, out=out)
print(result)

output:
tensor([[1, 1, 1],
        [3, 3, 3]])
tensor([2, 2, 2])
tensor([[0, 0, 0],
        [1, 1, 1]])

        在实际测试时,给out传入一个数字,程序会抛出异常,如下图所示。提示input,other,out这三个参数都应该是Tensor格式。因此给out传入一个与input相同shape的Tensor是较为恰当的做法。

3.官方文档

gt:https://pytorch.org/docs/stable/generated/torch.gt.html?highlight=torch+gt#torch.gt

ge:https://pytorch.org/docs/stable/generated/torch.ge.html?highlight=torch+ge#torch.ge

lt:https://pytorch.org/docs/stable/generated/torch.lt.html?highlight=torch+lt#torch.lt

le:https://pytorch.org/docs/stable/generated/torch.le.html?highlight=torch+le#torch.le

eq:https://pytorch.org/docs/stable/generated/torch.eq.html?highlight=torch+eq#torch.eq

ne:https://pytorch.org/docs/stable/generated/torch.ne.html?highlight=torch+ne#torch.ne

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值