axis这个参数表达的意思是某个维度。axis = 0 代表第一个维度,axis = 1 代表第二个维度
比如:下面这个三维数组
>>> X1 = np.random.randint(2,20,[2,3,4])
>>> X1
array([[[11, 10, 8, 5],
[13, 7, 9, 15],
[ 2, 13, 3, 9]],
[[ 7, 10, 7, 19],
[15, 15, 12, 6],
[ 6, 12, 6, 13]]])
我们查看一下他的维度
>>> X1.ndim
3
然后我们用求和X1.sum(axis = 0)来做一个示范,上面说了axis的意思是第一个维度,那.sum(axis = 0)这句话可以理解为在第一个维度上求和,相当于压缩第一个维度。我们来看求和之后的结果和它的维度。
>>> X1.sum(axis = 0)
array([[18, 20, 15, 24],
[28, 22, 21, 21],
[ 8, 25, 9, 22]])
>>> X1.sum(axis = 0).ndim
2
>>> X1.sum(axis = 0).shape
(3, 4)
>>>
我们可以看到,求和之后。原先的(2,3,4)的结构,变成了(3,4)。第一个维度被压缩了。
以此类推,X1.sum(axis = 1)这句代码,它所产生的结果应该是(2,4)结构的,类似于如下的样子
array([[x, x, x, x],
[x, x, x, x]])
根据结果的样子,我们很明显可以发现,只要把上面X1的3行压缩就可以达成
array([[[11, 10, 8, 5],
[13, 7, 9, 15],
[ 2, 13, 3, 9]],红色的矩阵,竖着相加[[ 7, 10, 7, 19],
[15, 15, 12, 6],
[ 6, 12, 6, 13]]])黄色的矩阵,竖着相加
结果就是
array([[26, 30, 20, 29],
[28, 37, 25, 38]])
尝试感受一下。理解之后你会感觉非常简单。下面来一个四维数组例子,代码有点长,还没理解的可以看看,感受一下
X1 = np.random.randint(2,20,[2,3,4,3])
>>> X1
array([[[[ 3, 19, 16],
[18, 14, 15],
[ 2, 13, 13],
[13, 19, 9]],
[[16, 17, 17],
[ 7, 3, 15],
[17, 8, 15],
[ 9, 14, 11]],
[[ 2, 17, 3],
[ 8, 6, 10],
[15, 10, 19],
[ 2, 5, 18]]],
[[[16, 7, 6],
[ 8, 13, 18],
[18, 14, 19],
[ 3, 14, 3]],
[[14, 13, 6],
[ 8, 15, 5],
[13, 4, 9],
[ 7, 4, 15]],
[[ 5, 18, 12],
[12, 10, 19],
[ 7, 8, 10],
[13, 19, 13]]]])
如果我们要对这个(2,3,4,3)结构的数组,进行X1.sum(axis = 0)的运算,得到的运算结果的shape就是(3,4,3)。所以最后的结果应该长这个样子:
array([[[x, x, x],
[x, x, x],
[x, x, x],
[x, x, x]],
[[x, x, x],
[x, x, x],
[x, x, x],
[x, x, x]],
[[x, x, x],
[x, x, x],
[x, x, x],
[x, x, x]]])
观察一下原先X1的样子,和应该得到的结果的样子,不难想象要如何计算。
array([[[[ 3, 19, 16],
[18, 14, 15],
[ 2, 13, 13],
[13, 19, 9]],[[16, 17, 17],
[ 7, 3, 15],
[17, 8, 15],
[ 9, 14, 11]],[[ 2, 17, 3],
[ 8, 6, 10],
[15, 10, 19],
[ 2, 5, 18]]],
[[[16, 7, 6],
[ 8, 13, 18],
[18, 14, 19],
[ 3, 14, 3]],[[14, 13, 6],
[ 8, 15, 5],
[13, 4, 9],
[ 7, 4, 15]],[[ 5, 18, 12],
[12, 10, 19],
[ 7, 8, 10],
[13, 19, 13]]]])相同颜色的矩阵相加
最后得到的结果就是
array([[[19, 26, 22],
[26, 27, 33],
[20, 27, 32],
[16, 33, 12]],
[[30, 30, 23],
[15, 18, 20],
[30, 12, 24],
[16, 18, 26]],
[[ 7, 35, 15],
[20, 16, 29],
[22, 18, 29],
[15, 24, 31]]])
讲到这里,大家应该差不多明白了吧,我就是这样一步步理解的。
原理方面,大家可以看一下这篇文章Python中np.sum()对axis的个人理解,超详细,里面的数学公式,可以帮助更好的理解。