基于
this solution到
Find the row indexes of several values in a numpy array,这是一个基于NumPy的解决方案,内存占用更少,并且在使用大阵列时可能是有益的 –
dims = np.maximum(B.max(0),A.max(0))+1
out = A[~np.in1d(np.ravel_multi_index(A.T,dims),np.ravel_multi_index(B.T,dims))]
样品运行 –
In [38]: A
Out[38]:
array([[1, 1, 1],
[1, 1, 2],
[1, 1, 3],
[1, 1, 4]])
In [39]: B
Out[39]:
array([[0, 0, 0],
[1, 0, 2],
[1, 0, 3],
[1, 0, 4],
[1, 1, 0],
[1, 1, 1],
[1, 1, 4]])
In [40]: out
Out[40]:
array([[1, 1, 2],
[1, 1, 3]])
大型阵列的运行时测试 –
In [107]: def in1d_approach(A,B):
...: dims = np.maximum(B.max(0),A.max(0))+1
...: return A[~np.in1d(np.ravel_multi_index(A.T,dims),\
...: np.ravel_multi_index(B.T,dims))]
...:
In [108]: # Setup arrays with B as large array and A contains some of B's rows
...: B = np.random.randint(0,9,(1000,3))
...: A = np.random.randint(0,9,(100,3))
...: A_idx = np.random.choice(np.arange(A.shape[0]),size=10,replace=0)
...: B_idx = np.random.choice(np.arange(B.shape[0]),size=10,replace=0)
...: A[A_idx] = B[B_idx]
...:
基于广播的解决方案的时间表 –
In [109]: %timeit A[np.all(np.any((A-B[:, None]), axis=2), axis=0)]
100 loops, best of 3: 4.64 ms per loop # @Kasramvd's soln
In [110]: %timeit A[~((A[:,None,:] == B).all(-1)).any(1)]
100 loops, best of 3: 3.66 ms per loop
基于更少内存占用的解决方案的时序 –
In [111]: %timeit in1d_approach(A,B)
1000 loops, best of 3: 231 µs per loop
进一步表现提升
in1d_approach通过将每行作为索引元组来减少每一行.通过使用np.dot引入矩阵乘法,我们可以更有效地做同样的事情,像这样 –
def in1d_dot_approach(A,B):
cumdims = (np.maximum(A.max(),B.max())+1)**np.arange(B.shape[1])
return A[~np.in1d(A.dot(cumdims),B.dot(cumdims))]
我们来测试它与以前的更大的数组 –
In [251]: # Setup arrays with B as large array and A contains some of B's rows
...: B = np.random.randint(0,9,(10000,3))
...: A = np.random.randint(0,9,(1000,3))
...: A_idx = np.random.choice(np.arange(A.shape[0]),size=10,replace=0)
...: B_idx = np.random.choice(np.arange(B.shape[0]),size=10,replace=0)
...: A[A_idx] = B[B_idx]
...:
In [252]: %timeit in1d_approach(A,B)
1000 loops, best of 3: 1.28 ms per loop
In [253]: %timeit in1d_dot_approach(A, B)
1000 loops, best of 3: 1.2 ms per loop