1. Common functions
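Some comparison functions compare element by element and behave like element-wise operations. Others behave like reduction operations. The commonly used comparison functions are gt/lt/ge/le/eq/ne (greater than, less than, greater than or equal to, less than or equal to, equal to, not equal to), topk (the k largest elements), sort (sorting) and max/min (maximum/minimum).
The element-wise comparisons gt/lt/ge/le/eq/ne have operator overloads, so you can write a >= b, a > b, a != b and a == b (as well as a < b and a <= b). The result is a BoolTensor (a ByteTensor in PyTorch versions before 1.2) and can be used as a mask to select elements.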
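A minimal sketch, assuming PyTorch 1.2 or later so that comparisons return `torch.bool`. Each overloaded operator gives the same result as the corresponding method, and the boolean result indexes the elements where the comparison holds. The tensors `a` and `b` are illustrative.

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])
b = t.tensor([[1., 1.], [3., 3.]])

# Each operator is shorthand for the method with the same meaning
print(t.equal(a >= b, a.ge(b)))   # True
print(t.equal(a > b, a.gt(b)))    # True
print(t.equal(a != b, a.ne(b)))   # True
print(t.equal(a == b, a.eq(b)))   # True

mask = a > b          # [[False, True], [False, True]], dtype torch.bool
print(a[mask])        # tensor([2., 4.]): a 1-D tensor of the selected elements
```

Indexing with a mask always returns a 1-D tensor, because the selected elements need not form a rectangle.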
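max/min behave a little differently. Taking max as an example, there are three ways to call it:
t.max(tensor): returns the single largest element of tensor;
t.max(tensor, dim): returns the largest values along the given dimension together with their indices;
t.max(tensor1, tensor2): compares the two tensors element-wise and returns the larger element at each position.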
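The three forms can be sketched as follows; the tensors `x` and `y` are made up for illustration. In the `dim` form the result is a namedtuple, so it can be unpacked directly into values and indices.

```python
import torch as t

x = t.tensor([[1., 5., 2.],
              [7., 0., 3.]])
y = t.full_like(x, 4.)   # a tensor of 4s with the same shape as x

overall = t.max(x)                  # form 1: the single largest element
values, indices = t.max(x, dim=1)   # form 2: largest value in each row and its column
elementwise = t.max(x, y)           # form 3: the larger of x and y at each position

print(overall)                # tensor(7.)
print(values)                 # tensor([5., 7.])
print(indices)                # tensor([1, 0])
print(elementwise.tolist())   # [[4.0, 5.0, 4.0], [7.0, 4.0, 4.0]]
```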
2. Usage examples
torch.gt, torch.lt, torch.ge, torch.le, torch.eq and torch.ne take the same kinds of arguments and return the same kind of result:
Args:
input (Tensor): the tensor to compare
other (Tensor or float): the tensor or value to compare
out (Tensor, optional): the output tensor that must be a BoolTensor
Returns:
Tensor: A torch.BoolTensor containing a True at each location where the comparison is true
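A short sketch of the `other` and `out` arguments, using `torch.gt` (the other five functions accept the same arguments). Besides a number or a same-shaped tensor, `other` may be any tensor that broadcasts against `input`; the 1-D tensor `row` below is an illustrative choice.

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])

# other as a number: every element is compared with 2
print(t.gt(a, 2).tolist())     # [[False, False], [True, True]]

# other as a tensor that broadcasts against a: row is compared with each row of a
row = t.tensor([0., 3.])
print(t.gt(a, row).tolist())   # [[True, False], [True, True]]

# out: write the result into a preallocated bool tensor
result = t.empty(2, 2, dtype=t.bool)
t.gt(a, 2, out=result)
print(result.tolist())         # [[False, False], [True, True]]
```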
2.1 torch.gt
In [1]: import torch as t
In [2]: a = t.Tensor([[1,2],[3,4]])
In [3]: a
Out[3]:
tensor([[1., 2.],
        [3., 4.]])
In [4]: a.gt(4)
Out[4]:
tensor([[False, False],
        [False, False]])
In [7]: a.gt(t.Tensor([[1,1], [3, 3]]))
Out[7]:
tensor([[False,  True],
        [False,  True]])
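The mask returned by `gt` can also be reduced: summing it counts the matching elements (each `True` counts as 1), and `nonzero()` lists their positions. A sketch on the same `a` as above, recreated so the snippet runs on its own; the threshold 2 is arbitrary.

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])
mask = a.gt(2)                 # [[False, False], [True, True]]

count = mask.sum().item()      # True counts as 1, so this is the number of matches
positions = mask.nonzero()     # one [row, col] pair per True element
print(count)                   # 2
print(positions.tolist())      # [[1, 0], [1, 1]]
```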
2.2 torch.lt
The parameters are the same as for torch.gt.
In [9]: a.lt(4)
Out[9]:
tensor([[ True,  True],
        [ True, False]])
In [10]: a.lt(t.Tensor([[1,1], [3, 3]]))
Out[10]:
tensor([[False, False],
        [False, False]])
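`torch.where` takes a boolean condition, such as the one returned by `lt`, and picks from one of two tensors at each position. A sketch that keeps the elements of `a` smaller than 3 and zeroes the rest (the threshold is arbitrary):

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])

# Where a < 3, keep a; elsewhere take the value from the zeros tensor
kept = t.where(a.lt(3), a, t.zeros_like(a))
print(kept.tolist())   # [[1.0, 2.0], [0.0, 0.0]]
```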
2.3 torch.ge
In [12]: a.ge(4)
Out[12]:
tensor([[False, False],
        [False,  True]])
In [13]: a.ge(t.Tensor([[1,1], [3, 3]]))
Out[13]:
tensor([[True, True],
        [True, True]])
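To ask whether a condition holds everywhere or anywhere, reduce the mask with `all()` or `any()`; both also accept a `dim` argument. A sketch with arbitrary thresholds:

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])

print(a.ge(1).all().item())          # True: every element is >= 1
print(a.ge(4).any().item())          # True: at least one element is >= 4
print(a.ge(5).any().item())          # False: no element is >= 5
print(a.ge(3).all(dim=1).tolist())   # [False, True]: per row, are all elements >= 3?
```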
2.4 torch.le
In [14]: a.le(4)
Out[14]:
tensor([[True, True],
        [True, True]])
In [15]: a.le(t.Tensor([[1,1], [3, 3]]))
Out[15]:
tensor([[ True, False],
        [ True, False]])
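Masks combine with `&` (and), `|` (or) and `~` (not), which is a common way to test a range. A sketch selecting the elements of `a` that lie in the closed interval [2, 3] (an arbitrary choice):

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])

in_range = a.ge(2) & a.le(3)    # True where 2 <= a <= 3
print(in_range.tolist())        # [[False, True], [True, False]]
print(a[in_range].tolist())     # [2.0, 3.0]
print((~in_range).tolist())     # [[True, False], [False, True]]: the complement
```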
2.5 torch.eq
In [16]: a.eq(4)
Out[16]:
tensor([[False, False],
        [False,  True]])
In [17]: a.eq(t.Tensor([[1,1], [3, 3]]))
Out[17]:
tensor([[ True, False],
        [ True, False]])
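`eq` compares element by element, whereas `torch.equal` returns a single Python `bool` telling whether two tensors have the same shape and elements. For floating-point results, exact equality can fail because of rounding, and `torch.isclose` compares within a tolerance instead. A sketch; the `b` here, which differs from `a` in one element, is made up for illustration.

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])
b = t.tensor([[1., 2.], [3., 5.]])

print(a.eq(b).tolist())   # [[True, True], [True, False]]: one result per element
print(t.equal(a, b))      # False: one answer for the whole tensor

# 0.1 + 0.2 is not exactly 0.3 in double precision
x = t.tensor(0.1, dtype=t.float64) + t.tensor(0.2, dtype=t.float64)
y = t.tensor(0.3, dtype=t.float64)
print(x.eq(y).item())            # False
print(t.isclose(x, y).item())    # True
```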
2.6 torch.ne
In [18]: a.ne(4)
Out[18]:
tensor([[ True,  True],
        [ True, False]])
In [19]: a.ne(t.Tensor([[1,1], [3, 3]]))
Out[19]:
tensor([[False,  True],
        [False,  True]])
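A common use of `ne` and `eq` is comparing predictions against labels. In this sketch `pred` and `label` are hypothetical class indices:

```python
import torch as t

pred = t.tensor([1, 0, 2, 2])    # hypothetical predicted classes
label = t.tensor([1, 1, 2, 0])   # hypothetical true classes

errors = pred.ne(label).sum().item()             # positions where they differ
accuracy = pred.eq(label).float().mean().item()  # fraction of positions that match
print(errors)     # 2
print(accuracy)   # 0.5
```

The `.float()` conversion is needed because `mean` is not defined for boolean tensors.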
2.7 torch.topk
The function signature is:
topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
Description and parameters:
Returns the :attr:`k` largest elements of the given :attr:`input` tensor along
a given dimension.
If :attr:`dim` is not given, the last dimension of the `input` is chosen.
If :attr:`largest` is ``False`` then the `k` smallest elements are returned.
A namedtuple of `(values, indices)` is returned, where the `indices` are the indices
of the elements in the original `input` tensor.
The boolean option :attr:`sorted` if ``True``, will make sure that the returned
`k` elements are themselves sorted.
Args:
input (Tensor): the input tensor.
k (int): the k in "top-k"
dim (int, optional): the dimension to sort along
largest (bool, optional): controls whether to return largest or
smallest elements
sorted (bool, optional): controls whether to return the elements
in sorted order
out (tuple, optional): the output tuple of (Tensor, LongTensor) that can be
optionally given to be used as output buffers
In [21]: a
Out[21]:
tensor([[1., 2.],
        [3., 4.]])
In [22]: a.topk(2)
Out[22]:
torch.return_types.topk(
values=tensor([[2., 1.],
        [4., 3.]]),
indices=tensor([[1, 0],
        [1, 0]]))
In [24]: a.topk(1, dim=1)
Out[24]:
torch.return_types.topk(
values=tensor([[2.],
        [4.]]),
indices=tensor([[1],
        [1]]))
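Two variations on the calls above, as a sketch: `largest=False` returns the k smallest elements instead, and flattening first gives the top k over the whole tensor, with indices that refer to positions in the flattened tensor.

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])

smallest = a.topk(1, dim=1, largest=False)   # the single smallest element of each row
print(smallest.values.tolist())    # [[1.0], [3.0]]
print(smallest.indices.tolist())   # [[0], [0]]

flat = a.flatten().topk(3)         # top 3 over all elements
print(flat.values.tolist())        # [4.0, 3.0, 2.0]
print(flat.indices.tolist())       # [3, 2, 1]: positions in the flattened tensor
```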
2.8 torch.sort
The function signature is:
sort(input, dim=-1, descending=False, out=None) -> (Tensor, LongTensor)
Description and parameters:
Sorts the elements of the :attr:`input` tensor along a given dimension
in ascending order by value.
If :attr:`dim` is not given, the last dimension of the `input` is chosen.
If :attr:`descending` is ``True`` then the elements are sorted in descending
order by value.
A namedtuple of (values, indices) is returned, where the `values` are the
sorted values and `indices` are the indices of the elements in the original
`input` tensor.
Args:
input (Tensor): the input tensor.
dim (int, optional): the dimension to sort along
descending (bool, optional): controls the sorting order (ascending or descending)
out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can
be optionally given to be used as output buffers
In [28]: b = t.randn(2,3)  # random values: yours will differ
In [29]: b
Out[29]:
tensor([[-0.1936, -1.8862,  0.1491],
        [ 0.8152,  1.1863, -0.4711]])
In [30]: b.sort()
Out[30]:
torch.return_types.sort(
values=tensor([[-1.8862, -0.1936,  0.1491],
        [-0.4711,  0.8152,  1.1863]]),
indices=tensor([[1, 0, 2],
        [2, 0, 1]]))
In [31]: b.sort(dim=0)
Out[31]:
torch.return_types.sort(
values=tensor([[-0.1936, -1.8862, -0.4711],
        [ 0.8152,  1.1863,  0.1491]]),
indices=tensor([[0, 0, 1],
        [1, 1, 0]]))
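Because `b` above is random, this sketch uses a fixed tensor `x` (made up for illustration) to show `descending=True`, to check that `indices` point into the original tensor (gathering with them rebuilds `values`), and to show `argsort`, which returns only the indices.

```python
import torch as t

x = t.tensor([[3., 1., 2.],
              [0., 5., 4.]])

values, indices = x.sort(dim=1, descending=True)
print(values.tolist())    # [[3.0, 2.0, 1.0], [5.0, 4.0, 0.0]]
print(indices.tolist())   # [[0, 2, 1], [1, 2, 0]]

# indices refer to positions in x, so gathering with them reproduces values
print(t.equal(x.gather(1, indices), values))   # True

# argsort returns only the indices (ascending by default)
print(x.argsort(dim=1).tolist())   # [[1, 2, 0], [0, 2, 1]]
```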
2.9 torch.max
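The function has the following signatures:
max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
max(input, other, out=None) -> Tensor
max(input) -> Tensor
The parameters are as follows: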
Args:
input (Tensor): the input tensor.
dim (int): the dimension to reduce.
keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
out (tuple, optional): the result tuple of two output tensors (max, max_indices)
Or:
Args:
input (Tensor): the input tensor.
other (Tensor): the second input tensor
out (Tensor, optional): the output tensor.
Example:
In [2]: import torch as t
In [4]: a = t.Tensor([[1,2], [3,4]])
In [5]: a
Out[5]:
tensor([[1., 2.],
        [3., 4.]])
In [7]: a.max()
Out[7]: tensor(4.)
In [8]: b = t.Tensor([[2,0], [2,6]])
In [9]: a.max(b)
Out[9]:
tensor([[2., 2.],
        [3., 6.]])
In [10]: a.max(dim=0)
Out[10]:
torch.return_types.max(
values=tensor([3., 4.]),
indices=tensor([1, 1]))
In [11]: a.max(dim=1)
Out[11]:
torch.return_types.max(
values=tensor([2., 4.]),
indices=tensor([1, 1]))
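A sketch of three related calls on the same `a` and `b`: `keepdim=True` keeps the reduced dimension with size 1, `argmax` returns only the indices, and `torch.maximum` is the dedicated element-wise form of `max(input, other)`. This assumes PyTorch 1.7 or later, where `torch.maximum` is available.

```python
import torch as t

a = t.tensor([[1., 2.], [3., 4.]])
b = t.tensor([[2., 0.], [2., 6.]])

row_max = a.max(dim=1, keepdim=True)
print(row_max.values.shape)          # torch.Size([2, 1]): reduced dim kept with size 1
print(row_max.values.tolist())       # [[2.0], [4.0]]

# argmax returns only the indices of the maxima
print(a.argmax(dim=1).tolist())      # [1, 1]

# torch.maximum gives the same element-wise result as a.max(b)
print(t.equal(t.maximum(a, b), a.max(b)))   # True
```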
2.10 torch.min
The function has the following signatures:
min(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
min(input, other, out=None) -> Tensor
min(input) -> Tensor
The parameters are as follows:
Args:
input (Tensor): the input tensor.
dim (int): the dimension to reduce.
keepdim (bool): whether the output tensor has :attr:`dim` retained or not.
out (tuple, optional): the tuple of two output tensors (min, min_indices)
Or:
Args:
input (Tensor): the input tensor.
other (Tensor): the second input tensor
out (Tensor, optional): the output tensor.
Example:
In [13]: a = t.Tensor([[1,2], [3,4]])
In [14]: a
Out[14]:
tensor([[1., 2.],
        [3., 4.]])
In [15]: a.min()
Out[15]: tensor(1.)
In [16]: b = t.Tensor([[2,0], [2,6]])
In [17]: b
Out[17]:
tensor([[2., 0.],
        [2., 6.]])
In [18]: a.min(b)
Out[18]:
tensor([[1., 0.],
        [2., 4.]])
In [19]: a.min(dim=0)
Out[19]:
torch.return_types.min(
values=tensor([1., 2.]),
indices=tensor([0, 0]))
In [20]: a.min(dim=1)
Out[20]:
torch.return_types.min(
values=tensor([1., 3.]),
indices=tensor([0, 0]))
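`keepdim=True` is handy when the reduced result is used in broadcasting arithmetic. As an illustration, this sketch rescales each row of a made-up tensor `x` to the range [0, 1] with min-max scaling. It assumes no row is constant, since that would divide by zero.

```python
import torch as t

x = t.tensor([[1., 3., 5.],
              [2., 2., 6.]])

# Per-row min and max, keeping the reduced dim so they broadcast across columns
row_min = x.min(dim=1, keepdim=True).values   # shape (2, 1): [[1.], [2.]]
row_max = x.max(dim=1, keepdim=True).values   # shape (2, 1): [[5.], [6.]]

scaled = (x - row_min) / (row_max - row_min)  # each row now spans [0, 1]
print(scaled.tolist())   # [[0.0, 0.5, 1.0], [0.0, 0.0, 1.0]]
```

Without `keepdim=True` the row minima would have shape (2,), which broadcasts against the columns instead of the rows and gives the wrong result (or a shape error).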