import numpy as np
# Demonstrate np.argmax with and without the ``axis`` argument.
#
# This was originally a pasted interactive session. The echoed outputs
# (``array([...], dtype=int64)``) are not valid statements in a script:
# ``array`` and ``int64`` are undefined at module scope. So they are kept as
# comments and checked with assertions instead. Index results are compared
# via ``.tolist()`` because the index dtype is platform-dependent (the
# session this came from printed int64).

a = np.array([[1, 1, 1],
              [2, 2, 2],
              [0, 3, 6]])

# No axis: the array is treated as flattened (row-major), so the result is a
# single flat index. The maximum, 6, is at flat position 8, i.e. row 2, col 2.
b1 = np.argmax(a)
# -> 8
assert b1 == 8
assert np.unravel_index(b1, a.shape) == (2, 2)

# axis=0: reduce down each column, giving one row index per column.
b2 = np.argmax(a, axis=0)
# -> array([1, 2, 2])
assert b2.tolist() == [1, 2, 2]

# axis=1: reduce across each row, giving one column index per row.
# Rows 0 and 1 are all ties; argmax returns the first occurrence, hence 0.
b3 = np.argmax(a, axis=1)
# -> array([0, 0, 2])
assert b3.tolist() == [0, 0, 2]