Gavin老师Transformer直播课感悟 - 基于Transformer的Rasa 3.x 内核解密之UnexpecTEDIntentPolicy算法源码及IntentTED详解(十六)

本文继续围绕工业级业务对话平台和框架Rasa对如何处理业务对话系统中来自用户的unexpected intent的场景从核心组件UnexpecTEDIntentPolicy的源码层面进行解析。

一、关于UnexpecTEDIntentPolicy组件算法源码分析

  1. UnexpecTEDIntentPolicy源码分析

      Rasa 3.x把所有的组件都抽象为graph component,并构建出各个graph component之间的依赖关系,这种依赖关系可以表达数据的生产者和消费者模型,Rasa 3.x基于DAG图的架构实现了系统基础架构和模型架构的分离,这样就可以使开发者只考虑在graph中使用什么样的模型,譬如通过下面的方法来指定UnexpecTEDIntentPolicy组件所使用的模型IntentTED,它会负责对数据进行处理:

通过方法run_training把输入数据转换为模型需要的特征数据vectors:

参数说明:

model_data:转换后的特征数据vectors

label_ids:与model_data里的数据相对应,如果为空,则会抛出RasaCoreException,因为模型在post training时会使用这些ids进行比较

在方法run_training内部会调用父类(TEDPolicy)的方法run_training:

然后会调用关键方法fit,这实际上是调用tensorflow的Keras model的fit方法:

在UnexpecTEDIntentPolicy的方法run_training里还调用了方法compute_label_quantiles_post_training,这个方法用于计算是否触发” action_unlikely_intent”的分数,在推理时针对每一个label都会计算多个分数,再根据” tolerance”设定的值来决定触发action ” action_unlikely_intent”的 threshold:

在这个方法里会调用IntentTED的核心方法run_bulk_inference,根据输入数据RasaModelData调用模型进行预测,在方法最后调用了RasaModel的方法run_inference:

        在方法参数里,batch_size可以是int,也可以是一个List,方法返回类型是Dict。在方法里,调用了create_data_generators来获得data_generator,data_generator是由输入和输出构成的一个二元组,根据iter方法把data_generator转换为data_iterator。在while循环里,可以看到调用了next方法来获取一条数据,然后调用方法_rasa_predict进行预测。当遍历完数据后,会抛出异常StopIteration。根据batch_in产出每一条数据的batch_out后,需要调用方法_merge_batch_outputs进行合并。

在方法_rasa_predict里,会根据各种条件判断进行相应的处理:

        关于iterator,generator,container之间的关系和使用,可以参考下面这个图,根据iter方法把container中的一个iterable对象转换为iterator,然后再通过方法next(这里使用了懒加载,在大数据量的情况下特别有用)获取iterator中的值:

方法_prepare_data_for_prediction用于把训练数据转换为可以用于模型预测的数据:

方法predict_action_probabilities根据tracker预测next action:

    获取最后一次的预测来检查用户意图是否与对话上下文匹配,如果预测相似度低于这个用户意图对应的计算得到的threshold,那么就认为这个意图是unlikely的:

  

              根据条件判断query intent是否低于threshold来返回结果:

 方法_prediction返回PolicyPrediction,PolicyPrediction是存储一个Policy预测的相关信息,其中:

              -probabilities表示每个action的概率

              -policy_priority表示policy的优先级,机器学习的policy会比基于rule的优先级低(因为rule是确定的)

              -is_end_to_end_prediction表示是否是end-to-end learning的方式

        方法_should_check_for_intent会检查是否需要产生action “action_unlikely_intent”,这里的参数domain是指Rasa基于配置文件运行的对话机器人的实例:

         方法_pick_thresholds用于计算每个label id的threshold:

   2. IntentTED源码分析

     IntentTED使用TED模型架构,但是它的作用是用来预测用户意图而不是像TED那样去预测next action:

方法dot_product_loss_layer用于获取点积计算loss layer, 这里使用了针对多意图的类MultiLabelDotProductLoss:

方法_get_labels_embed根据label ids从缓存中获取之前计算过的label的embeddings,而不是重新去处理这些labels,这样可以节省时间,加快推理时的系统响应速度:

方法run_bulk_inference根据输入信息进行模型预测:

                                                                                                               

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值