在菜鸟教程上学习numpy,提及numpy的高级索引,有个例子是给出一个二维矩阵,取出这个矩阵的四个角的元素,代码如下:
import numpy as np
x = np.array([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
print ('我们的数组是:' )
print (x)
print ('\n')
rows = np.array([[0,0],[3,3]])
cols = np.array([[0,2],[0,2]])
y = x[rows,cols]
print ('这个数组的四个角元素是:')
print (y)
输出结果应该是:
我们的数组是:
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
这个数组的四个角元素是:
[[ 0 2]
[ 9 11]]
这一段代码我死也没看明白是怎么回事,主要障碍是rows和cols这两个矩阵是如何构建x的索引矩阵的。后来看了一篇博文:《NumPy之四:高级索引和索引技巧》才算弄清楚。
在理解博文的基础上,解释如下:
前述代码给出超额二维矩阵应该是这个样子的:
x =
[[0 1 2]
[3 5 5]
[6 7 8]
[9 10 11]]
以下代码分别构造了行索引矩阵rows和列索引矩阵cols:
rows = numpy.array([[0, 0], [3, 3]])
cols = numpy.array([[0, 2], [0, 2]])
这时行索引矩阵长这样:
rows =
[[r00 r01]
[r10 r11]]
=
[[0 0]
[3 3]]
而列索引矩阵长这模样:
cols =
[[c00 c01]
[c10 c11]]
=
[[0 2]
[0 2]]
由这两个索引矩阵构成的四个索引坐标对为:
(r00, c00), (r01, c01), (r10, c10), (r11, c11)
索引矩阵为:
[[(r00, c00) (r01, c01)]
[(r10, c10) (r11, c11)]]
=
[[(0, 0), (0, 2)]
[(3, 0), (3, 2)]]
用这四个索引坐标去从x矩阵中取出四个元素:
y(0, 0) = x(r00, c00) = x(0, 0) = 0
y(0, 1) = x(r01, c01) = x(0, 2) = 2
y(1, 0) = x(r10, c10) = x(3, 0) = 9
y(1, 1) = (r11, c11) = x(3, 2) = 11
因此,最终四个角构成的矩阵为:
y = x[rows, cols] =
[[0 2]
[9 11]]
其实吧,这事儿得反过来看。x矩阵的四个角的坐标分别是(0, 0)、(0, 2)、(3, 0)和(3, 2),这四个坐标对的第一个值就是对应x矩阵四个角的行坐标(即,四个角的行坐标分别是0, 0, 3, 3),第二个值对应x矩阵四个角的列坐标(即,四个角的列坐标分别是0, 2, 0, 2)。由此,可分别构造行坐标列表和列坐标列表如下:
rows = [0, 0, 3, 3]
cols = [0, 2, 0, 2]
上图按列对齐看,每一列的两个数就组成了一组坐标对。
由于四个角要构成矩阵样式,因此行列坐标也分别用矩阵表达:
rows = numpy.array([[0, 0], [3, 3]])
cols = numpy.array([[0, 2], [0, 2]])
最后,用y = x[rows, cols]就取出了x的四个角,并构成新的矩阵如下:
[[0 2]
[9 11]]