数组的轴
数组的轴是一个很重要的概念,也是numpy数组中最不好理解的一个概念,它经常出现在np.sum(), np.max() 这样关键的聚合函数中。
先看一个例子:
interest_score = np.random.randint(10, size=(4,3))
interest_score
array([[0, 8, 1],
[6, 2, 6],
[8, 1, 8],
[6, 4, 0]])
axis = 0 的维度计算求和
np.sum(interest_score, axis=0)
array([20, 15, 15])
axis = 1 的维度计算求和
np.sum(interest_score, axis=1)
array([ 9, 14, 17, 10])
对于二维的,是比较好理解的。
1 轴可以理解为 一个学生的总分
0 轴可以理解为每一科的总分
高维数组
高维数组的轴的理解就有一点难了。
a = np.arange(18).reshape(3, 2, 3)
a
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]])
这个数组,谁是轴0,1,2呢?
np.sum(a)
153
np.sum 是计算所有的元素。
先感性的认识一下这3个轴
aixs = 0
np.sum(a, axis=0)
array([[18, 21, 24],
[27, 30, 33]])
aixs = 1
np.sum(a, axis=1)
array([[ 3, 5, 7],
[15, 17, 19],
[27, 29, 31]])
axis = 2
np.sum(a, axis=2)
array([[ 3, 12],
[21, 30],
[39, 48]])
答案揭晓:
a.max(axis=0)
array([[12, 13, 14],
[15, 16, 17]])
a.max(axis=1)
array([[ 3, 4, 5],
[ 9, 10, 11],
[15, 16, 17]])
a.max(axis=2)
array([[ 3, 4, 5],
[ 9, 10, 11],
[15, 16, 17]])
本来以为学清楚了,其实还是挺糊涂的。
这个时候可以通过减少维度的形式来解释
[[0,1,2],[3,4,5]]
axis =0 的最大值是
[3,4,5]
axis = 1 的最大值
[2,5]
这样是不是清楚了。
axis 参数非常常见,不光光出现在刚才介绍的 sum 与 max,还有很多其他的聚合函数也会用到,例如 min、mean、argmin(求最小值下标)、argmax(求最大值下标)等。