今天使用Numpy时出现了一个bug,经过和同学的讨论才最终得以解决…
1. 问题:
给定数组:
a = np.array([
[1,10,1],
[2,20,2],
[3,30,3],
[2,40,4]
])
要求:如果某一行的第0个元素为2,那么将它的第1个元素改为100
即期望输出为:
[[ 1 10 1]
[ 2 100 2]
[ 3 30 3]
[ 2 100 4]]
2. 错误做法:
mask = a[:, 0] == 2
a[mask][:,1] = 100
[Out]:
[[ 1 10 1]
[ 2 20 2]
[ 3 30 3]
[ 2 40 4]]
3. 正确做法:
mask = a[:, 0] == 2 # [False True False True]
a[mask, 1] = 100
[Out]:
[[ 1 10 1]
[ 2 100 2]
[ 3 30 3]
[ 2 100 4]]
4. 分析:
Numpy推荐用单个括号来访问所有维度(中间用逗号隔开)。
上文的错误做法中,使用了两个括号来访问,这使得在a[mask]这一步中系统就copy了一个新对象出来,更改该新对象并不会对原数组a产生影响(这部分是基于一些实验的猜测)。(其实如果单元素地访问,比如a[0][0]=100,其实是不会有问题的,但是更复杂场景下,比如本问题使用了腌膜mask,就容易出问题)
而正确做法中,使用单括号来访问就避免了这个问题。
最后做个总结就是,Numpy中访问数组要养成“单括号多逗号”的习惯,而不要沿用C++等语言中的“多括号”访问多维数组的习惯。
5. 补(多括号)
多括号的正确使用姿势示例:(注意上文中mask是一维的,下面的mask是二维)
给定二维数组a,要求:将数组a 左上角2*2区域 中 小于3或大于4 的元素置0。
# initialize
a = np.array([
[1,2,3],
[4,5,6],
[7,8,9]
])
# mask
aa = a[:2, :2]
mask = np.bitwise_or(aa<3,aa>4) # 2*2的掩膜
# result
a[:2, :2][mask] = 0
print(a)
[Out]:
[[0 0 3]
[4 0 6]
[7 8 9]]