np.sum()函数中axis参数的理解:
import numpy as np
a = np.array([[[1, 2, 3, 2],[1, 2, 3, 1], [2, 3, 4, 1]],
[[1, 0, 2, 0], [2, 1, 2, 0], [2, 1, 1, 1]]])
print(a.sum(axis=0))
'''
[[2 2 5 2]
[3 3 5 1]
[4 4 5 2]]
相当于把第0维压缩(对应值相加)成1:2*3*4 -> (1*)3*4
'''
print(a.sum(axis=1))
'''
[[4 7 10 4]
[5 2 5 1]]
相当于把第1维压缩(对应值相加)成1:2*3*4 -> 2*(1*)4
'''
print(a.sum(axis=2))
'''
[[ 8 7 10]
[ 3 5 5]]
相当于把第2维压缩(对应值相加)成1:2*3*4 -> 2*3(*1)
'''