"""
Calculate knn by multi-process
"""
import time
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
import ctypes
from contextlib import closing
def _knn(args):
    """
    Calculate knn for queries with numeric_id of `features` in `numeric_id_range`
    Args:
        args: tuple of args
            args[0]: feats_shape, shape of the full feature matrix
            args[1]: numeric_id_range, range object of numeric_id, len(numeric_id_range) = bs
            args[2]: n_knn, number of nearest neighbors as graph adjacent
    Inherited (set per worker by `_pool_init`):
        _shm_feats: all features, shared memory buffer viewed as float32
    Returns:
        numeric_id_range: input index range
        topk_ids: local knn result, (bs, n_knn) numpy.ndarray
                  (top-k column order is unspecified, as with argpartition)
    """
    # unzip args
    feats_shape, numeric_id_range, n_knn = args
    # create zero-copy numpy view upon shared memory
    feats_shd = np.frombuffer(_shm_feats, dtype=np.float32).reshape(feats_shape)
    # cosine similarity of this batch against all rows
    # (features are assumed L2-normalized by the caller — TODO confirm)
    cos_sim = np.dot(feats_shd[numeric_id_range], feats_shd.T)
    # BUG FIX: kth must be -n_knn so the LAST n_knn columns are guaranteed to
    # hold the n_knn LARGEST similarities. The original kth=n_knn partitioned
    # around the n_knn-th *smallest* value, leaving the tail columns undefined.
    topk_ids = np.argpartition(cos_sim, -n_knn, axis=1)[:, -n_knn:]
    return numeric_id_range, topk_ids
def _pool_init(shm_feats):
"""Emit shared memory as global variables
"""
global _shm_feats
_shm_feats = shm_feats
def multiprocess_knn(feats, n_knn, n_workers, batch_size):
    """Compute knn by multiprocessing

    Args:
        feats: normalized input features, (total, dim) numpy.ndarray
        n_knn: number of nearest neighbors in knn algorithm
        n_workers: number of worker processes
        batch_size: number of queries handled per task
    Returns:
        (total, n_knn) int32 numpy.ndarray of nearest-neighbor indices
    """
    since = time.time()
    # shape parameters
    total = len(feats)
    feats_shape = feats.shape
    # create shared memory (Read-Only after the copy below, so no synchronization)
    shm_feats = multiprocessing.RawArray(ctypes.c_float, feats.size)
    # BUG FIX: the original allocated the shared buffer but never copied
    # `feats` into it, so workers computed knn over all-zero features.
    np.frombuffer(shm_feats, dtype=np.float32)[:] = \
        np.ascontiguousarray(feats, dtype=np.float32).ravel()
    del feats  # save memory
    # BUG FIX: ceil division; `total // batch_size + 1` produced a spurious
    # empty batch whenever total was an exact multiple of batch_size
    batch_num = (total + batch_size - 1) // batch_size
    batch_cnt = 0  # completed batch counter
    total_topk_ids = np.zeros(shape=(total, n_knn), dtype=np.int32)  # pre-allocated result
    tasks = [(feats_shape,
              range(batch_size * batch, min(batch_size * (batch + 1), total)),
              n_knn)
             for batch in range(batch_num)]
    # multiprocessing.pool; initializer hands the shared buffer to each worker
    with closing(multiprocessing.Pool(processes=n_workers,
                                      initializer=_pool_init,
                                      initargs=(shm_feats,))) as pool:
        for id_range, local_topk in pool.imap_unordered(_knn, tasks):
            total_topk_ids[id_range] = local_topk
            batch_cnt += 1
            print('{}/{}'.format(batch_cnt, batch_num), flush=True)  # write to file without buffering
    duration = int(time.time() - since)
    # BUG FIX: minutes are the remainder within the hour, not total minutes
    print('batch_size: {}, time: {}h {}min {}s'.format(
        batch_size, duration // 3600, (duration % 3600) // 60, duration % 60))
    return total_topk_ids