召回池
我的想法是,输入形式为batch时模型总计算时间应当比一个一个喂小不少,因此可以建立一个比如200为大小的召回池,最多每隔0.5s送入模型进行计算,当waiting list已经到达了200,立即执行计算,重置定时任务。
消息队列
python的queue库是一个线程安全的队列,可以用作消息队列。
其基本用法参考:
Python之queue模块以及生产消费者模型
先写一个类包装
class Pack:
def __init__(self, flag=False, _id=None, seq=None, hidden=None):
self.flag = flag # True means activation signal
self._id = _id
self.seq = seq
if self.hidden:
self.hidden = hidden
else:
self.hidden = [0] * setting['hidden_dim']
flag为True时,意味着这是一个有意义的输入;False时用作Timer对队列的唤醒令牌。
我们的召回池,wating_list就是消息队列:
class BatchPool:
def __init__(self, pool_size=200, timeout=0.5):
self.timeout = timeout
self.uid2idx = dict()
self.res = dict()
self.waiting_list = Queue(pool_size)
self.data = {'seq': list(), 'lengths': list(), 'hidden': list()}
self.get_recalls()
def clear(self):
for key in self.data.keys():
self.data[key].clear()
def get_recalls(self):
while True:
item = self.waiting_list.get()
if not item.flag:
self.uid2idx[item._id] = len(self.data['seqs'])
self.data['seq'].append(item.seq)
self.data['hidden'].append(item.hidden)
self.data['lengths'].append(len(item.seq))
if item.flag or self.waiting_list.full():
top_items = recommander.recall(**self.data)
for uid, idx in self.uid2idx:
self.res[uid] = top_items[idx]
self.clear()
self.waiting_list.task_done()
当wating_list为空,消费者会阻塞在get()处。最后召回完成后,task_done()使所有join()挂起的线程唤醒。
但现在有个问题,如何保证所有发起请求者都拿到自己所需的数据后,召回池才开始新一轮召回任务?这里生产者变成了召回池,消费者变成了请求线程。
可以,让生产者向消费者发n次消息,然后等在一个condition上,消费者取走到之后判断队列是否为空,空则唤醒生产者。但这样有个问题,如果有个消费者线程挂掉了,那么永远没有唤醒生产者的第n个消费者。
因此需要让消费者通知生产者,生产者等n次通知继续,或者timeout继续。当然,这个方式还是存在问题,如果timeout太大,下一轮迟迟不能开始,如果太小,有可能来不及让存活的消费者拿走。不过进程挂掉的概率很小,所以可以暂不考虑。
def get_recalls(self):
while True:
item = self.waiting_list.get()
if item.flag:
self.uid2idx[item._id] = len(self.data['seqs'])
self.data['seq'].append(item.seq)
self.data['hidden'].append(item.hidden)
self.data['lengths'].append(len(item.seq))
if not item.flag and len(self.uid2idx) > 0 or self.waiting_list.full():
top_items = recommander.recall(**self.data)
for uid, idx in self.uid2idx:
self.res[uid] = top_items[idx]
self.waiting_list.task_done()
for i in range(len(self.uid2idx)):
try:
self.notify_queue.get()
except Exception as e:
print(e)
self.clear()
def ask_for_recall(self, _id, seq, hidden):
self.waiting_list.put(Pack(True, _id, seq, hidden))
self.waiting_list.join()
self.notify_queue.put(1)
定时任务
python提供了Timer定时器,但是只能按固定的interval开启。
因此使用APScheduler库进行调度。
self.shed = BackgroundScheduler()
self.job = self.shed.add_job(self.awake, 'interval', seconds=timeout)
self.shed.start()
self.get_recalls()
def awake(self):
if len(self.uid2idx) > 0:
self.waiting_list.put(Pack())
def get_recalls(self):
while True:
item = self.waiting_list.get()
if item.flag:
self.uid2idx[item._id] = len(self.data['seqs'])
self.data['seq'].append(item.seq)
self.data['hidden'].append(item.hidden)
self.data['lengths'].append(len(item.seq))
if not item.flag and len(self.uid2idx) > 0 or len(self.uid2idx) >= self.pool_size:
self.job.pause()
top_items = recommander.recall(**self.data)
for uid, idx in self.uid2idx:
self.res[uid] = top_items[idx]
self.waiting_list.task_done()
for i in range(len(self.uid2idx)):
try:
self.notify_queue.get()
except Exception as e:
print(e)
self.clear()
self.job.resume()
可能遇到的错误 APScheduler: LookupError: No trigger by the name “interval” was found
优化与Debug
后来想了想,消息队列的最大大小不能是200,必须比他大,因为在消费者召回计算中,还有新的生产者向消息队列中添加。设为无限大?问题是在新生产者添加后进入等待,消费者计算完之后notify所有的生产者,新生产者被提前唤醒了。
所以设置200比较安全,但是如果定时器一直发消息,模型一直在计算,那么有可能消息队列里全都是唤醒令牌。所以必须在计算时让定时器暂停,这样就没问题了。
另外,get_recalls不能在主线程调用,不然会一直等在get()处,要加一个线程运行它
recommander = Recommander(setting['model'])
pool = BatchPool()
thread = threading.Thread(target=pool.get_recalls)
thread.start()
取
因为api调用的进程需要等待召回池返回,需要挂起。可以使用用户级别的协程来管理。需要使用async关键字声明异步方法,使用await等待返回。
async def ask_for_recall(self, _id, seq, hidden):
self.waiting_list.put(Pack(True, _id, seq, hidden))
self.waiting_list.join()
iids = self.res[_id]['iids']
emb = self.res[_id]['emb']
self.notify_queue.put(1)
return iids, emb
async def get_rec_items(id, emb):
iids, emb = await pool.ask_for_recall(id, emb)
return iids, emb
@api_view(['GET'])
def recommand(requests):
emb = None
_id = requests.COOKIES.get('_id')
user = None
if _id:
user = User.objects.only(['emb']).with_id(ObjectId('_id'))
if not user:
return JsonResponse({'err': '用户不存在,请重新登录'},
json_dumps_params={'ensure_ascii': False})
if user.get('emb'):
emb = user['emb']
seq = requests.COOKIES.get('sess')
loop = asyncio.get_event_loop()
future = asyncio.ensure_future(get_rec_items('_id', emb))
loop.run_until_complete(future)
iids, emb = future.result()
if user:
user.emb = emb
user.save()
items = Details.objects.filter(sourceId__in=iids)
return JsonResponse({'rec': items},
json_dumps_params={'ensure_ascii': False})
loop = asyncio.get_event_loop()
future = asyncio.ensure_future(get_rec_items('_id', emb))
loop.run_until_complete(future)
iids, emb = future.result()
这是一般调用的关键代码,但发起的请求线程并非主线程,get_event_loop会报错,建议用run()来执行,或者new一个loop
future = asyncio.ensure_future(get_rec_items(_id, seq, emb))
asyncio.run(future)
API
/api/recommand
无参数
待办事项
现在如果sess为空不推荐,应当通过计算item(余弦)相似度的方式推荐。