先上结论:a+=x是在a的地址上原地修改,而a=a+x是在新的地址上存入了相加得到的值
今晚在给代码debug时遇到一个数组溢出的警告,把中间数据打印出来发现一个数值不可能超过10的np.ndarray竟然短短几步冲到了Inf值,实在诡异。经过一步步排查,最终确定问题出在一个不起眼的“+=”运算符上...特此记下今晚花一小时排出来的坑。
请看下面这段代码:
from collections import deque
max_len = 3
buffer_1 = deque(maxlen=max_len)
buffer_2 = deque(maxlen=max_len)
for i in range(10):
buffer_1.append(np.ones(4))
buffer_2.append(np.ones(4))
a, b = buffer_1[-1], buffer_2[-1]
if len(buffer_1) >= max_len: # 显然,buffer1和2的长度总是相等
for j in range(max_len - 1):
a += buffer_1[j]
b = b + buffer_2[j]
print(str(i), np.all(buffer_1[0] == 1), np.all(buffer_2[0] == 1))
其中,buffer为一个双端队列,我用它主要是为了限制列表长度,使其在添加新元素时自动丢掉超出长度的旧元素。下面每次新添加一个全1向量,并在列表填满后计算所有np.ndarray的和。
在过往认知中,“+=”与“=...+...”完全相同,然鹅!其输出如下:
0 True True
1 True True
2 True True
3 True True
4 False True
5 False True
6 False True
7 False True
8 False True
9 False True
可见,只要使用...=...+...写法的就没问题,而“+=”则在过程中偷偷篡改了buffer的数据。
为什么呢??让我们问问ChatGPT:
这段代码模拟了一个循环,其中使用了两个具有相同最大长度(11)的deque(双端队列),并在每次迭代中向这两个队列中添加一个全1的向量。然后,当队列长度达到其最大长度时,它会遍历队列中的向量,并将它们累加到
a
和b
变量中。你提到的现象发生的原因在于numpy数组的+=
操作符行为。在这段代码中,变量
a
和b
被初始化为队列中最后一个元素的引用。当你使用+=
操作符对a
进行操作时,你实际上是在原地(in-place)修改这个numpy数组。由于a
是buffer_1
中最后一个元素的引用,这意味着你也在修改buffer_1
中存储的数组。因此,当ep
变量的值大于等于11时,buffer_1
的第一个元素已经被累加了10次全1向量,导致其值不再是全1向量。而对于
b
和buffer_2
,虽然每一步看似进行了相同的操作,但实际上b = b + buffer_2[i]
这一步创建了b
的一个新副本,因此原buffer_2
中的向量没有被修改,仍然是全1向量。所以,buffer_2
中的向量即便在循环多次后,其内容还是全1向量。这种行为区别归因于
+=
操作和+
操作的差异。+=
操作符修改了左侧变量指向的内存中的值,而+
操作符生成了一个新的对象。由于buffer_1
中的对象被a
通过+=
直接修改,所以当循环次数超过队列的最大长度时,你会观察到buffer_1
中的元素不再是初始的全1向量。
噢,破案了,原来是因为a = buffer_1[-1]我赋给a的是列表最后一个元素的地址,在a+=x之后,求完和的结果也就保存在了buffer_1[-1]的位置。下一次新元素存入会把这个改变过的元素挤到[-2]的位置,再求和时就会把刚才变大的数组作为加数加进去,如此循环几次之后元素就会变得非常大啦~ 而b=b+x能得到正确的结果则是因为b是在新的地址上操作的,不会改变buffer_2的内容
我们验证一下,在每个i循环的末尾把每一步的buffer_1打印出来看看:
deque([array([1., 1., 1., 1.])], maxlen=3)
deque([array([1., 1., 1., 1.]), array([1., 1., 1., 1.])], maxlen=3)
deque([array([1., 1., 1., 1.]), array([1., 1., 1., 1.]), array([3., 3., 3., 3.])], maxlen=3)
deque([array([1., 1., 1., 1.]), array([3., 3., 3., 3.]), array([5., 5., 5., 5.])], maxlen=3)
deque([array([3., 3., 3., 3.]), array([5., 5., 5., 5.]), array([9., 9., 9., 9.])], maxlen=3)
deque([array([5., 5., 5., 5.]), array([9., 9., 9., 9.]), array([15., 15., 15., 15.])], maxlen=3)
deque([array([9., 9., 9., 9.]), array([15., 15., 15., 15.]), array([25., 25., 25., 25.])], maxlen=3)
deque([array([15., 15., 15., 15.]), array([25., 25., 25., 25.]), array([41., 41., 41., 41.])], maxlen=3)
deque([array([25., 25., 25., 25.]), array([41., 41., 41., 41.]), array([67., 67., 67., 67.])], maxlen=3)
deque([array([41., 41., 41., 41.]), array([67., 67., 67., 67.]), array([109., 109., 109., 109.])], maxlen=3)
可见,第三次的[-1]元素是当时三个全1向量加起来得到,第四次的[-1]元素则是1+3+1=4,第五次是3+5+1=9,第六次是5+9+1=15,以此类推。