a = np.ones((2, 3, 4), dtype=int)
a.sum(axis=0)
a.sum(axis=1)
a.sum(axis=2)
实际上,axis 的逻辑可理解为如下:
a.sum(axis=0) 等价于 a[0, :, :] + a[1, :, :]
a.sum(axis=1) 等价于 a[:, 0, :] + a[:, 1, :] + a[:, 2, :]
a.sum(axis=2) 等价于 a[:, :, 0] + a[:, :, 1] + a[:, :, 2] + a[:, :, 3]
例如 axis = 2,表示把数组第 2 个维度上的所有切片进行 sum() 操作,大概就是这个意思。