First, create some data:
import numpy as np
np.random.seed(1)
list_of_np_1D = np.random.randint(0, 5, size=(500, 6))
np_2D = np.random.randint(0, 5, size=(20, 6))
Run the code:
^{pr2}$
Output:
CPU times: user 161 ms, sys: 2 ms, total: 163 ms
Wall time: 167 ms
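Here is a faster version. It uses broadcasting to form all pairwise sums, the .view() method to reinterpret each row as a single fixed-width byte string, set() to drop the duplicate strings, and then converts the remaining strings back to an array:
%%time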
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
stype = "S%d" % (r.itemsize * np_2D.shape[1])
fill_set2 = set(r.ravel().view(stype).tolist())
res2 = np.zeros(len(fill_set2), dtype=stype)
res2[:] = list(fill_set2)
res2 = res2.view(r.dtype).reshape(-1, np_2D.shape[1])
Output:
CPU times: user 13 ms, sys: 1 ms, total: 14 ms
Wall time: 14.6 ms
Check the result (both arrays are put into the same lexsort order before comparing):
np.all(res1[np.lexsort(res1.T), :] == res2[np.lexsort(res2.T), :])
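To make the byte-string step easier to follow, here is a minimal sketch of the same .view() and set() round trip on a tiny array. The array `rows` and the names `keys`, `unique_keys`, and `unique_rows` are made up for illustration and are not part of the answer's code; the sketch also assumes a C-contiguous integer array, which is what the view relies on.

```python
import numpy as np

# Hypothetical 3x3 input: rows 0 and 2 are identical.
rows = np.array([[1, 2, 3],
                 [4, 5, 6],
                 [1, 2, 3]], dtype=np.int64)

n_cols = rows.shape[1]
# One fixed-width byte string per row: itemsize * number of columns bytes.
stype = "S%d" % (rows.itemsize * n_cols)   # "S24" for three int64 columns

# Reinterpret the flat buffer so each row becomes a single string element.
keys = rows.ravel().view(stype)            # shape (3,)

# Python's set() removes the duplicate byte strings.
unique_keys = set(keys.tolist())

# Write the strings into a fixed-width array, then view them as integers again.
out = np.zeros(len(unique_keys), dtype=stype)
out[:] = list(unique_keys)
unique_rows = out.view(rows.dtype).reshape(-1, n_cols)

# A set has no defined order, so sort the rows before inspecting them.
unique_rows = unique_rows[np.lexsort(unique_rows.T)]
print(unique_rows)
```

Because the set is unordered, the rows come back in arbitrary order, which is why the check above sorts both results with lexsort() before comparing them.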
You can also use lexsort() to remove the duplicates:
%%time
r = list_of_np_1D[:, None, :] + np_2D[None, :, :]
r = r.reshape(-1, r.shape[-1])
r = r[np.lexsort(r.T)]
idx = np.where(np.all(np.diff(r, axis=0) == 0, axis=1))[0] + 1
res3 = np.delete(r, idx, axis=0)
Output:
CPU times: user 13 ms, sys: 3 ms, total: 16 ms
Wall time: 16.1 ms
Check the result (res3 is already in lexsort order):
np.all(res1[np.lexsort(res1.T), :] == res3)
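As a cross-check, the lexsort() approach can be wrapped in a small helper and compared with np.unique(..., axis=0), which is available in NumPy 1.13 and later. This is only a sketch: the helper name `unique_rows_lexsort`, the smaller input sizes, and the use of a boolean mask instead of np.delete() are choices made here for illustration, not the answer's own code. The helper assumes a 2-D array with at least one row.

```python
import numpy as np

def unique_rows_lexsort(a):
    """Return the distinct rows of a 2-D array, in lexsort order.

    Hypothetical helper; assumes `a` has at least one row.
    """
    # lexsort uses the last key as the primary key, so rows end up
    # ordered by the last column first.
    a = a[np.lexsort(a.T)]
    # After sorting, a row is a duplicate when it equals the row before it.
    is_dup = np.all(np.diff(a, axis=0) == 0, axis=1)
    # The first row is always kept; the rest are kept if not duplicates.
    keep = np.concatenate(([True], ~is_dup))
    return a[keep]

# Smaller stand-ins for the two input arrays (sizes are assumptions).
rng = np.random.RandomState(1)
left = rng.randint(0, 5, size=(50, 6))
right = rng.randint(0, 5, size=(20, 6))

# All pairwise row sums, flattened to one row per pair.
sums = (left[:, None, :] + right[None, :, :]).reshape(-1, 6)

via_lexsort = unique_rows_lexsort(sums)

# np.unique orders rows by the first column first, so re-sort it into
# lexsort order before comparing.
via_unique = np.unique(sums, axis=0)
via_unique = via_unique[np.lexsort(via_unique.T)]

same = np.array_equal(via_lexsort, via_unique)
print(same)
```

np.unique(..., axis=0) is the shortest way to write this; whether it is faster than the set() or lexsort() versions would need to be timed on the actual data.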