torch.max学习记录:
首先定义数据:
import torch
a = torch.randn(2,3)
print("a:",a)
结果如下:
a: tensor([[-0.5658, -0.9736, -1.1753],
[ 1.2006, 0.4078, -2.0542]])
torch.
max
(input, dim)
按维度dim 返回最大值
torch.max(a,dim=0) 返回值为一个元组,元组里包含两个值,第一个值为一个每一列中最大元素,第二个值为最大元素在这一列的行索引
返回的又是列又是行的,这句话怎么理解呢?
可以理解成:dim=0,第0个维度表示行,可以想象是你的手,从上往下挤压(对应dim=0,第一行,第二行...从上往下的一个行方向),直到压扁的一个过程(这个要理解清楚)。在此过程中,只保存每一列的最大值,同时记录下这个最大值是第几行(即行索引)。
[ [-0.5658, -0.9736, -1.1753],<