导读
最近有时间看一些目标检测项目的代码(基于Pytorch),里边很多Pytorch的相关操作都忘记了,特来此记录一下,用以加深记忆,而且还能以备一样处境的同学前来查询。今天的主角是torch.argmax(input, dim, keepdim=False)。
官方文档地址
https://pytorch.org/docs/stable/generated/torch.argmax.html
torch.argmax(input) → LongTensor
Returns the indices of the maximum value of all elements in the input tensor.
根据官方的解释,该函数可以返回输入张量中所有元素的最大值的索引。当然这只是最初级的用法,根据输入参数的不同,其返回的结果也不同。下面我们一起了解它的参数都有哪些作用。
参数解析
Parameters
- input (Tensor) : the input tensor.
- dim (int) :the dimension to reduce. If None, the argmax of the flattened input is returned.
- keepdim (bool) : whether the output tensor has dim retained or not. Ignored if dim=None.
这是官网上对参数的解释,input就是我们输入的要操作的张量;dim是我们选择的要在张量的哪个维度上进行计算,输出这个维度最大值的索引,这里一行元素变成一个索引,所以官网中用了reduce;keepdim是询问输出是否与输入保持一样的形状,默认是不保持(False)。
举例演示
首先输入一个张量,注意我们输入的这个张量的shape为[5, 9]
>>> x = torch.randn(5,9)
>>> print(x)
tensor([[ 0.3918, 0.3978, 0.2819, -0.8487, -1.0499, 0.2124, -1.3527, -1.5335,
1.1050],
[ 0.8450, -0.3717, -0.4705, -0.4024, 2.1019, -0.8545, 1.9085, 0.5792,
-0.4279],
[ 0.1993, -0.2887, 0.4467, 0.4878, 1.4934, -1.3862, 0.3576, -0.2363,
-2.0700],
[ 0.0536, 0.9385, 1.2661, -0.3469, -0.5772, -0.7822, 0.8315, -1.7256,
-0.4979],
[ 1.1592, -0.1604, 0.2798, 0.5974, 0.1782, -2.3354, -1.7775, -0.8366,
1.8993]])
接下来,现在第一个维度上进行操作
>>> torch.argmax(x,dim=0)
tensor([4, 3, 3, 4, 1, 0, 1, 1, 4])
第一个维度是行,即按行计算,我们看到结果输出的维度为9,正好是输入张量x的列数。torch.argmax()的计算方式如下:
每次在所有行的相同位置取元素,然后计算取得元素集合的最大值索引。
第一次取所有行的第一位元素,x[:, 0], 得到
tensor([0.3918, 0.8450, 0.1993, 0.0536, 1.1592])
第二次取所有行的第二位元素,x[:, 1], 得到
tensor([0.3978, -0.3717, -0.2887, 0.9385, -0.1604])
依次类推,x有9列,我们也可以取9次,所有取的结果如下:
tensor([ 0.3918, 0.8450, 0.1993, 0.0536, 1.1592])
tensor([ 0.3978, -0.3717, -0.2887, 0.9385, -0.1604])
tensor([ 0.2819, -0.4705, 0.4467, 1.2661, 0.2798])
tensor([-0.8487, -0.4024, 0.4878, -0.3469, 0.5974])
tensor([-1.0499, 2.1019, 1.4934, -0.5772, 0.1782])
tensor([-1.3527, 1.9085, 0.3576, 0.8315, -1.7775])
tensor([-1.5335, 0.5792, -0.2363, -1.7256, -0.8366])
tensor([ 1.1050, -0.4279, -2.0700, -0.4979, 1.8993])
然后分别计算以上每个张量中元素的最大值的索引,便得到tensor([4, 3, 3, 4, 1, 0, 1, 1, 4])
同理,按照列来操作也是一样的思路,这里就不详细说了,看结果:
>>> torch.argmax(x,dim=1)
tensor([8, 4, 4, 2, 8])
经过上边例子的演示,我们可以知道torch.argmax(input,dim)可以返回input中dim维度上的最大值索引。
我们给x在目标检测中赋予具体的含义,假如x的形状为[num_bbox, anchor],那么x便是5个预测框分别与9个anchor计算得到的交并比,我们要选出来与预测框交并比最大的那个anchor,用来回归预测框越来越接近GT。这时候就要用到torch.argmax()找到与bbox交并比最大的anchor的序号。
>>> torch.argmax(x,dim=1)
tensor([8, 4, 4, 2, 8])
即与第一个预测框交并比最大的是第9个anchor,与第二个预测框交并比最大的是第5个anchor…