-
函数作用
torch.gt(a,b)函数比较a中元素大于(这里是严格大于)b中对应元素,大于则为1,不大于则为0,这里a为Tensor,b可以为与a的size相同的Tensor或常数。 -
代码示例
>>> import torch
>>> a=torch.randn(2,4)
>>> a
tensor([[-0.5466, 0.9203, -1.3220, -0.7948],
[ 2.0300, 1.3090, -0.5527, -0.1326]])
>>> b=torch.randn(2,4)
>>> b
tensor([[-0.0160, -0.3129, -1.0287, 0.5962],
[ 0.3191, 0.7988, 1.4888, -0.3341]])
>>> torch.gt(a,b) #得到a中比b中元素大的位置
tensor([[0, 1, 0, 0],
[1, 1, 0, 1]], dtype=torch.uint8)
>>> torch.gt(b,a) #b中比a中大
tensor([[1, 0, 1, 1],
[0, 0, 1, 0]], dtype=torch.uint8)
>>> torch.gt(a,1)
tensor([[0, 0, 0, 0],
[1, 1, 0, 0]], dtype=torch.uint8)
>>> c=torch.Tensor([[1,2,3],[4,5,6]])
>>> d=torch.Tensor([[1,1,3],[5,5,5]])
>>> torch.gt(c,d) #必须是严格大于才为1
tensor([[0, 1, 0],
[0, 0, 1]], dtype=torch.uint8)