- 先看个例子
import numpy as np
a = np.arange(16).reshape(2,2,4)
print(a)
print('='*20)
print(a.sum(axis=0))
- 结果
[[[ 0 1 2 3]
[ 4 5 6 7]]
[[ 8 9 10 11]
[12 13 14 15]]]
====================
[[ 8, 10, 12, 14],
[16, 18, 20, 22]]
- 分析
axis
的值是a.shape
元组的下标- 将数组元素用对应下标表示
000 001 002 003
010 011 012 013
100 101 102 103
110 111 112 113
axis=0
表示对于一个元素而言,只有第一个轴变化,其他两个轴不变,从而获得不同元素,比如:000
, 100
是一组,010
, 110
是一组
a.sum(axis=0)
表示将000
与100
对应的元素相加,010
与110
对应的元素相加,以此类推,最后组成一个shape
为(2, 4)
的数组