本文继续围绕工业级业务对话平台和框架Rasa对如何处理业务对话系统中来自用户的unexpected intent的场景从核心组件UnexpecTEDIntentPolicy的源码层面进行解析。
一、关于UnexpecTEDIntentPolicy组件算法源码分析
- UnexpecTEDIntentPolicy源码分析
Rasa 3.x把所有的组件都抽象为graph component,并构建出各个graph component之间的依赖关系,这种依赖关系可以表达数据的生产者和消费者模型,Rasa 3.x基于DAG图的架构实现了系统基础架构和模型架构的分离,这样就可以使开发者只考虑在graph中使用什么样的模型,譬如通过下面的方法来指定UnexpecTEDIntentPolicy组件所使用的模型IntentTED,它会负责对数据进行处理:
通过方法run_training把输入数据转换为模型需要的特征数据vectors:
参数说明:
model_data:转换后的特征数据vectors
label_ids:与model_data里的数据相对应,如果为空,则会抛出RasaCoreException,因为模型在post training时会使用这些ids进行比较
在方法run_training内部会调用父类(TEDPolicy)的方法run_training:
然后会调用关键方法fit,这实际上是调用tensorflow的Keras model的fit方法:
在UnexpecTEDIntentPolicy的方法run_training里还调用了方法compute_label_quantiles_post_training,这个方法用于计算是否触发” action_unlikely_intent”的分数,在推理时针对每一个label都会计算多个分数,再根据” tolerance”设定的值来决定触发action ” action_unlikely_intent”的 threshold:
在这个方法里会调用IntentTED的核心方法run_bulk_inference,根据输入数据RasaModelData调用模型进行预测,在方法最后调用了RasaModel的方法run_inference:
在方法参数里,batch_size可以是int,也可以是一个List,方法返回类型是Dict。在方法里,调用了create_data_generators来获得data_generator,data_generator是由输入和输出构成的一个二元组,根据iter方法把data_generator转换为data_iterator。在while循环里,可以看到调用了next方法来获取一条数据,然后调用方法_rasa_predict进行预测。当遍历完数据后,会抛出异常StopIteration。根据batch_in产出每一条数据的batch_out后,需要调用方法_merge_batch_outputs进行合并。
在方法_rasa_predict里,会根据各种条件判断进行相应的处理:
关于iterator,generator,container之间的关系和使用,可以参考下面这个图,根据iter方法把container中的一个iterable对象转换为iterator,然后再通过方法next(这里使用了懒加载,在大数据量的情况下特别有用)获取iterator中的值:
方法_prepare_data_for_prediction用于把训练数据转换为可以用于模型预测的数据:
方法predict_action_probabilities根据tracker预测next action:
获取最后一次的预测来检查用户意图是否与对话上下文匹配,如果预测相似度低于这个用户意图对应的计算得到的threshold,那么就认为这个意图是unlikely的:
根据条件判断query intent是否低于threshold来返回结果:
方法_prediction返回PolicyPrediction,PolicyPrediction是存储一个Policy预测的相关信息,其中:
-probabilities表示每个action的概率
-policy_priority表示policy的优先级,机器学习的policy会比基于rule的优先级低(因为rule是确定的)
-is_end_to_end_prediction表示是否是end-to-end learning的方式
方法_should_check_for_intent会检查是否需要产生action “action_unlikely_intent”,这里的参数domain是指Rasa基于配置文件运行的对话机器人的实例:
方法_pick_thresholds用于计算每个label id的threshold:
2. IntentTED源码分析
IntentTED使用TED模型架构,但是它的作用是用来预测用户意图而不是像TED那样去预测next action:
方法dot_product_loss_layer用于获取点积计算loss layer, 这里使用了针对多意图的类MultiLabelDotProductLoss:
方法_get_labels_embed根据label ids从缓存中获取之前计算过的label的embeddings,而不是重新去处理这些labels,这样可以节省时间,加快推理时的系统响应速度:
方法run_bulk_inference根据输入信息进行模型预测: