对于numpy包中的axis参数的理解
在numpy中,对于多维数组进行sum,mean,min,max,sort的操作时,均会涉及到axis这一参数。那么axis具体是什么呢?
我们先引入一个切片(slices)的概念:
- 如果 k k k维数组 A ∈ R n 1 × n 2 × . . . × n k A \in \mathbb{R}^{n_1\times n_2 \times ...\times n_k} A∈Rn1×n2×...×nk ,则 A A A的Slices为固定其中1个索引位置之后形成的 k − 1 k-1 k−1维数组
ok,有了slices的概念,我们如何对高维数组进行sum,mean,min,max,sort之类的运算呢?
以sum函数为例。对于一个 k k k维数组 A ∈ R n 1 × n 2 × . . . × n k A \in \mathbb{R}^{n_1\times n_2 \times ...\times n_k} A∈Rn1×n2×...×nk,在axis=i上进行运算就是在第i个维度上对相应的slices进行运算。
以sum函数在axis=i上运算为例,这个过程就相当于计算:
n p . s u m ( A , a x i s = i ) = ∑ j = 1 n i x [ : , : , . . . , : , j , . . . , : ] . np.sum(A,axis=i) = \sum_{j=1}^{n_i}x[:,:,...,:,j,...,:]. np.sum(A,axis=i)=