numpy 的花式索引
numpy的花式索引方式有很多种,特别是多维度索引的情况下比较绕,需要理解清楚。
- 一维的花式索引,
i
数组的值对应的就是a
数组的索引,这样输出了一个新的数组
# input
import numpy as np
a = np.arange(12) ** 2
i = np.array([1, 1, 3, 8, 5])
print(a[i])
# output
[ 1 1 9 64 25]
- 一维变化, 可以看到输出的是一个二维数组了。
- 首先知道
a
的shape=(12, ),i
的shape=(2, 3), 输出的shape是根据a
和i
的shape来确定的,具体的下面详谈。在这个实例中,确定了输出为(2, 3) - 按照
i
提供的值作为a
的索引值,挨个获取值,放入到输出的数组中
# input
import numpy as np
a = np.arange(12) ** 2
i = np.array([[1, 2, 3], [4, 5, 6]])
print(a[i].shape)
print(a[i])
# output
(2, 3)
[[ 1 4 9]
[16 25 36]]
- 简单的多维。
- 规则:当数组
a
是多维的时,单个索引数组指的是第一个维度 a[i]
使用的是单个索引数组i
,所以[1, 2, 3, 4]
各个值分别对应a
的第一个维度"1"
表示获取a[1]=[ 4 5 6 7]"2"
表示获取a[2]=[ 8 9 10 11]"3"
表示获取a[3]=[12 13 14 15]"4"
表示获取a[4]=[16 17 18 19]
- 那么输出就变成了[a[1], a[2], a[3], a[4]], 将真实的值代入,就得到我们的程序输出
- 所以,按照
i
的输出shape=(4, ),每个元素shape=(4, ), 合并之后,输出的shape为(4, 4)
# input
import numpy as np
a = np.arange(20).reshape(5, 4)
i = np.array([1, 2, 3, 4])
print(a[i].shape)
print(a[i])
# output
(4, 4)
[[ 4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]
[16 17 18 19]]
- 复杂的多维
按照i
输出shape=(2, 2),但是获取的每个元素shape=(4, ), 合并之后,输出shape=(2, ,2, 4)
# input
import numpy as np
a = np.arange(20).reshape(5, 4)
i = np.array([[1, 2], [3, 4]])
print(a[i].shape)
print(a[i])
# output
(2, 2, 4)
[[[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]]]
- 复杂的多维(2)
上面示例中,都是a[i]的形式输出的,那么对于a[:, i]来说,":"
代表先是从第0个维度开始获取值到最后一个。c = a[:, i]
就转换成
c = [a[0][i], a[1][i], a[2][i], a[3][i], a[4][i]]
这个就等于每个元素按照示例2的形式处理,获得最终的输出。
# input
import numpy as np
a = np.arange(20).reshape(5, 4)
i = np.array([[1, 2], [3, 0]])
c = a[:, i]
print(c.shape)
print(c)
# output
(5, 2, 2)
[[[ 1 2]
[ 3 0]]
[[ 5 6]
[ 7 4]]
[[ 9 10]
[11 8]]
[[13 14]
[15 12]]
[[17 18]
[19 16]]]
参考:https://www.numpy.org.cn/user/quickstart.html#%E8%8A%B1%E5%BC%8F%E7%B4%A2%E5%BC%95%E5%92%8C%E7%B4%A2%E5%BC%95%E6%8A%80%E5%B7%A7