Numpy:连续索引元素赋值失败的问题
前言
最近用numpy做索引赋值时,发现了一个连续索引的问题,记录一下。
连续索引切片
numpy数组可以通过整数index索引(Integer array indexing
)或者元素True False(Boolean array indexing)索引:
import numpy as np
a=np.arange(10).reshape([5, 2])
a[[0, 2, 3]] # integer index
a[[True, True, False, False, False], True] # True False
切片的结果类似元素引用,也就是说,修改切片结果时,原数组的元素也会被修改:
import numpy as np
a=np.arange(10).reshape([5, 2])
a[[0, 2, 3]] = [0, 0]
print(a)
'''
array([[0, 0],
[2, 3],
[0, 0],
[0, 0],
[8, 9]])
'''
这个特性使得修改数组中的元素非常容易。然而,连续索引赋值可能会出现失效。
连续切片问题
上面的例子是单次索引赋值,而连续索引的需求也是存在的。如果上面的单次索引方法改成下面的多次索引:
a[[0, 2, 3]][:]=[0, 0]
print(a)
'''
array([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
'''
非常神奇的事情发生了,原数组元素没有被更改。
numpy索引:基础索引,高级索引
要理解上面连续索引失败问题,需要先了解numpy的两种索引机制:基础索引和高级索引。
基础索引得到的结果是原数组的一份view
或者一个标量,可以认为是一种浅拷贝或者元素引用机制,修改原数组元素的值时,view
数组也会一起改变,反之亦然。
高级索引得到的结果是原数组的一份copy
,可以认为是深拷贝,新数组不受到原数组的影响。
那么基础索引和高级索引分别是什么呢?简单而言可以如下区分:
基础索引:单元素索引,切片索引,维度索引。简单而言,形如a[0], a[1:-1:1], a[..., :]
等都是基础索引。
高级索引:ndarray索引,非元组序列索引(比如list),元组中含有一个序列的索引。简单而言,形如a[[0, 1]], a[np.array([1])], a[((1, 2), 0)]
的都是高级索引
注意:python中,a[1,2,3]
是a[(1,2,3)]
的简化表示(语法糖),完全等效。
如果一个数组索引中同时存在基础索引和高级索引,则这是一个混合索引,比如a[:, [0]]
是混合索引。而a[:, 0]
则是基础索引。
numpy数组的view, copy属性对连续索引的影响
连续基础索引时,由于view属性会改变原数组,因此连续基础索引不会导致元素赋值失败的问题:
import numpy as np
a=np.arange(10).reshape([5, 2])
a[:1][:1] = 9, 9
print(a)
'''
array([[9, 9],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
'''
上面这个连续切片索引也是基础索引,第一次切片产生了一个view数组,这个view数组做索引赋值,原数组的元素也一起被修改了。
下面这个连续高级索引,第一次索引产生了一个copy数组,这个copy数组做索引赋值,并不影响原数组的元素。
import numpy as np
a=np.arange(10).reshape([5, 2])
a[[0]][[0]] = [9, 9]
print(a)
'''
array([[1, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
'''
连续混合索引
连续混合索引的结果非常有趣:先基础索引再高级索引,原数组被改变;先高级索引再低级索引,原数组不变,这个很好理解。
import numpy as np
a=np.arange(10).reshape([5, 2])
a[:1][[0]] = [9, 9]
print(a)
'''
array([[9, 9],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
'''
a[[0]][:1] = [-5, -5]
print(a)
'''
array([[9, 9],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
'''
单次多维混合索引
无论是a[[0], :]
还是a[:, [0]]
,单次多维混合索引都是可以正常元素赋值的,但是对返回元组的view copy属性有影响,第一维是基础索引时,返回数组是view的,否则是copy的。
可以将a[[0], 1:]
理解为a[:, 1:][[0], :]
,即先从原数组中取对应的列,再从新数组中取对应的行。
判断数组的view, copy属性
可以通过ndarray
的base
属性判断:
a.base is None # True表示是copy,False表示是view