import torch a=torch.randn(5,3) print(a,torch.max(a,1),torch.max(a,1)[1]) ''' tensor([[ 1.0588, 0.4321, -1.1318], [-1.5585, -0.5181, -0.6880], [ 2.1610, 1.1872, -0.7473], [-1.9867, -1.0905, 0.6602], [-0.5579, -0.2183, -1.0601]]) (tensor([ 1.0588, -0.5181, 2.1610, 0.6602, -0.2183]), tensor([0, 1, 0, 2, 1])) tensor([0, 1, 0, 2, 1]) torch.max函数对于输入的torch.tensor沿着其某一个维度上计算最大值 返回两个 第一项为tensor沿着某个维度上最大的数值,第二项为最大值的位置索引值 '''
torch.max
最新推荐文章于 2024-02-09 23:37:29 发布