PyTorch 笔记(08)— Tensor 比较运算(torch.gt、lt、ge、le、eq、ne、torch.topk、torch.sort、torch.max、torch.min)

1. 常用函数

比较函数中有一些是逐元素比较,操作类似逐元素操作,还有一些类似归并操作,常用的比较函数如下表所示。
比较函数
表中第一行的比较操作已经实现了运算符重载,因此可以使用 a>=ba>ba !=ba == b,其返回的结果是一个 ByteTensor,可用来选取元素。

max/min 操作有些特殊,以 max 为例,有以下三种使用情况:

  • t.max(tensor) : 返回 tensor 中最大的一个数;
  • t.max(tensor,dim) : 指定维上最大数,返回 tensor 和下标;
  • t.max(tensor1,tensor2) : 比较两个 tensor 相比较大的元素;

2. 使用示例

torch.gttorch.lttorch.getorch.letorch.eqtorch.ne 的函数参数和返回值是类似的,都如下所示:

  • Args:
    input (Tensor): the tensor to compare
    other (Tensor or float): the tensor or value to compare
    out (Tensor, optional): the output tensor that must be a BoolTensor

  • Returns:
    Tensor: A torch.BoolTensor containing a True at each location where comparison is true

2.1 torch.gt

In [1]: import torch as t

In [2]: a = t.Tensor([[1,2],[3,4]])

In [3]: a
Out[3]: 
tensor([[1., 2.],
        [3., 4.]])

In [4]: a.gt(4)
Out[4]: 
tensor([[False, False],
        [False, False]])

In [7]: a.gt(t.Tensor([[1,1], [3, 3]]))
Out[7]: 
tensor([[False,  True],
        [False,  True]])

2.2 torch.lt

函数参数同 torch.gt

In [9]: a.lt(4)
Out[9]: 
tensor([[ True,  True],
        [ True, False]])

In [10]: a.lt(t.Tensor([[1,1], [3, 3]]))
Out[10]: 
tensor([[False, False],
        [False, False]])

In [11]: 

2.3 torch.ge

In [12]: a.ge(4)
Out[12]: 
tensor([[False, False],
        [False,  True]])

In [13]: a.ge(t.Tensor([[1,1], [3, 3]]))
Out[13]: 
tensor([[True, True],
        [True, True]])

2.4 torch.le

In [14]: a.le(4)
Out[14]: 
tensor([[True, True],
        [True, True]])

In [15]: a.le(t.Tensor([[1,1], [3, 3]]))
Out[15]: 
tensor([[ True, False],
        [ True, False]])

In [16]: 

2.5 torch.eq

In [16]: a.eq(4)
Out[16]: 
tensor([[False, False],
        [False,  True]])

In [17]: a.eq(t.Tensor([[1,1], [3, 3]]))
Out[17]: 
tensor([[ True, False],
        [ True, False]])

2.6 torch.ne

In [18]: a.ne(4)
Out[18]: 
tensor([[ True,  True],
        [ True, False]])

In [19]: a.ne(t.Tensor([[1,1], [3, 3]]))
Out[19]: 
tensor([[False,  True],
        [False,  True]])

2.7 torch.topk

函数定义如下:

topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

参数说明:

Returns the :attr:`k` largest elements of the given :attr:`input` tensor along
    a given dimension.
    
    If :attr:`dim` is not given, the last dimension of the `input` is chosen.
    
    If :attr:`largest` is ``False`` then the `k` smallest elements are returned.
    
    A namedtuple of `(values, indices)` is returned, where the `indices` are the indices
    of the elements in the original `input` tensor.
    
    The boolean option :attr:`sorted` if ``True``, will make sure that the returned
    `k` elements are themselves sorted
    
    Args:
        input (Tensor): the input tensor.
        k (int): the k in "top-k"
        dim (int, optional): the dimension to sort along
        largest (bool, optional): controls whether to return largest or
               smallest elements
        sorted (bool, optional): controls whether to return the elements
               in sorted order
        out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be
            optionally given to be used as output buffers
In [21]: a
Out[21]: 
tensor([[1., 2.],
        [3., 4.]])

In [22]: a.topk(2)
Out[22]: 
torch.return_types.topk(
values=tensor([[2., 1.],
        [4., 3.]]),
indices=tensor([[1, 0],
        [1, 0]]))

In [24]: a.topk(1, dim=1)
Out[24]: 
torch.return_types.topk(
values=tensor([[2.],
        [4.]]),
indices=tensor([[1],
        [1]]))

In [25]: 

2.8 torch.sort

函数定义如下:

sort(input, dim=-1, descending=False, out=None) -> (Tensor, LongTensor)

函数参数如下:

    Sorts the elements of the :attr:`input` tensor along a given dimension
    in ascending order by value.
    
    If :attr:`dim` is not given, the last dimension of the `input` is chosen.
    
    If :attr:`descending` is ``True`` then the elements are sorted in descending
    order by value.
    
    A namedtuple of (values, indices) is returned, where the `values` are the
    sorted values and `indices` are the indices of the elements in the original
    `input` tensor.
    
    Args:
        input (Tensor): the input tensor.
        dim (int, optional): the dimension to sort along
        descending (bool, optional): controls the sorting order (ascending or descending)
        out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can
            be optionally given to be used as output buffers
In [28]: b = t.randn(2,3)

In [29]: b
Out[29]: 
tensor([[-0.1936, -1.8862,  0.1491],
        [ 0.8152,  1.1863, -0.4711]])

In [30]: b.sort()
Out[30]: 
torch.return_types.sort(
values=tensor([[-1.8862, -0.1936,  0.1491],
        [-0.4711,  0.8152,  1.1863]]),
indices=tensor([[1, 0, 2],
        [2, 0, 1]]))

In [31]: b.sort(dim=0)
Out[31]: 
torch.return_types.sort(
values=tensor([[-0.1936, -1.8862, -0.4711],
        [ 0.8152,  1.1863,  0.1491]]),
indices=tensor([[0, 0, 1],
        [1, 1, 0]]))

In [32]: 

2.9 torch.max

函数定义如下:

max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

函数参数如下:

Args:
    input (Tensor): the input tensor.
    dim (int): the dimension to reduce.
    keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
    out (tuple, optional): the result tuple of two output tensors (max, max_indices)

或者:

 Args:
     input (Tensor): the input tensor.
     other (Tensor): the second input tensor
     out (Tensor, optional): the output tensor.

示例如下:

In [2]: import torch as t

In [4]: a = t.Tensor([[1,2], [3,4]])

In [5]: a
Out[5]: 
tensor([[1., 2.],
        [3., 4.]])

In [7]: a.max()
Out[7]: tensor(4.)

In [8]: b = t.Tensor([[2,0], [2,6]])

In [9]: a.max(b)
Out[9]: 
tensor([[2., 2.],
        [3., 6.]])

In [10]: a.max(dim=0)
Out[10]: 
torch.return_types.max(
values=tensor([3., 4.]),
indices=tensor([1, 1]))

In [11]: a.max(dim=1)
Out[11]: 
torch.return_types.max(
values=tensor([2., 4.]),
indices=tensor([1, 1]))

In [12]: 

2.10 torch.min

函数定义如下:

min(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
min(input, other, out=None) -> Tensor
min(input) -> Tensor

函数参数如下:

Args:
    input (Tensor): the input tensor.
    dim (int): the dimension to reduce.
    keepdim (bool): whether the output tensor has :attr:`dim` retained or not.
    out (tuple, optional): the tuple of two output tensors (min, min_indices)

或者:

Args:
    input (Tensor): the input tensor.
    other (Tensor): the second input tensor
    out (Tensor, optional): the output tensor.

使用示例:

In [13]: a = t.Tensor([[1,2], [3,4]])

In [14]: a
Out[14]: 
tensor([[1., 2.],
        [3., 4.]])

In [15]: a.min()
Out[15]: tensor(1.)

In [16]: b = t.Tensor([[2,0], [2,6]])

In [17]: b
Out[17]: 
tensor([[2., 0.],
        [2., 6.]])

In [18]: a.min(b)
Out[18]: 
tensor([[1., 0.],
        [2., 4.]])

In [19]: a.min(dim=0)
Out[19]: 
torch.return_types.min(
values=tensor([1., 2.]),
indices=tensor([0, 0]))

In [20]: a.min(dim=1)
Out[20]: 
torch.return_types.min(
values=tensor([1., 3.]),
indices=tensor([0, 0]))

In [21]: 
  • 6
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值