np.argmax&torch.max()对比

argmax函数


通俗来说:在axis的增长方向上求最大值

np.argmax()

import numpy as np
a = np.array([
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ],
 
              [
                  [-1, 7, -5, 2],
                  [9, 6, 2, 8],
                  [3, 7, 9, 1]
              ],
              
           	  [
                  [21, 6, -5, 2],
                  [9, 36, 2, 8],
                  [3, 7, 79, 1]
              ]
            ])
b=np.argmax(a, axis=0)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=0时,是在a[0]方向上找最大值,即三个矩阵做比较,具体
						  [1, 5, 5, 2],
                          [-1, 7, -5, 2],
                          [21, 6, -5, 2],
#一共有三个,所以最终得到的结果b就为34列矩阵
print(b)
[[2 1 0 0]
 [0 2 0 0]
 [1 0 2 0]]
 
c=np.argmax(a, axis=1)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=1时,是在a[1]方向上找最大值,即在列方向比较,此时就是指在每个矩阵内部的列方向上进行比较
                    	[1, 5, 5, 2],
                  		[9, -6, 2, 8],
                  	    [-3, 7, -9, 1]
#一共有三个,所以最终得到的结果b就为34列矩阵
print(c)
[[1 2 0 1]
 [1 0 2 1]
 [0 1 2 1]]
 
d=np.argmax(a, axis=2)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=2时,是在a[2]方向上找最大值,即在行方向比较,此时就是指在每个矩阵内部的行方向上进行比较
                   	  [1, 5, 5, 2],
                   	  [9, -6, 2, 8],
                   	  [-3, 7, -9, 1]
#寻找第一行的最大值,可以看出第一行[1, 5, 5, 2]最大值为5,,索引值为1
print(d)
[[1 0 1]
 [1 0 2]
 [0 1 2]]
##################################################################
# 第一个矩阵,取最后一行的所有列
m=np.argmax(a[0, -1, :])
print(m)
#1
 
# 第二个矩阵,取第三行的所有列
h=np.argmax(a[1, 2, :])
print(h)
#2

# 第二个矩阵,取所有行的第三列
g=np.argmax(a[1,:, 2])
print(g)
#2
import numpy as np
import numpy as np

arrary = np.array([
    [
        [1, 5, 5, 2],
        [9, -6, 2, 8],
        [-3, 7, -9, 1]
    ],

    [
        [-1, 7, -5, 2],
        [9, 6, 2, 8],
        [3, 7, 9, 1]
    ],

    [
        [21, 6, -5, 2],
        [9, 36, 2, 8],
        [3, 7, 79, 1]
    ]
    ])
print(arrary.shape)

a = np.argmax(arrary, axis=0)
b = np.argmax(arrary, axis=1)
c = np.argmax(arrary, axis=2)
print('argmax axis = 0 is ', a)
print('argmax axis = 1 is ', b)
print('argmax axis = 2 is ', c)
import torch
d = torch.from_numpy(arrary)
print(arrary)
d = torch.argmax(d, dim= -0)
print('torch argmax is :', d)
torch.argmax(arrary, dim = -1)
#dim可以有-3-2-1012
torch.max(a,0) 返回每一列中最大值的那个元素,且返回其索引(返回最大元素在这一列的行索引)
axis=0表示以行的维度为基准,行上的所有数据所在列上的最大值,通俗来说:在axis的增长方向上求最大值
torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
axis=1表示以列的维度为基准,列上的所有数据所在行上的最大值,通俗来说:在axis的增长方向上求最大值
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值