Gavin老师Transformer直播课感悟 - 基于Transformer的Rasa 3.x 内核解密之对话策略Policy完整源码详解(二十一)

        本文继续围绕工业级业务对话平台和框架Rasa的对话策略Policy的完整源码进行解析。

一、关于对话策略Policy完整源码解析

  1. Policy的初始化及子类相关源码剖析

    在Policy中通过create方法来创建一个policy graph component,由于Policy继承自GraphComponent,所以必须实现这个接口里定义的抽象方法create:

具体参数说明如下:

config:在具体组件里需要使用自己的config来覆盖default_config

model_storage:模型训练完成后通过持久化操作保存到这个存储中,在进行推理时需要从这个存储中加载模型

resource:相当于一个资源定位器,通过它来查找存储在model_storage中的graph component(即加载component所使用的模型)

execution_context:当前graph运行上下文信息

下面的load方法用于从存储中加载一个组件(component):

        在参数部分除了使用上面create方法中提到的这几个参数之外,还使用了可以传入任意值的参数” **kwargs”,这是因为当前component可以接收来自前面一个或多个graph nodes的输出结果,这就是在Rasa 3.x 架构图(DAG)中所表达的组件之间的依赖关系。

        下面是Policy这个类中实现的load方法,用于从存储中加载一个已训练的policy component,也就是在try里面通过resource加载featurizer,然后再通过cls构建出这个policy的实例,即调用Policy的方法__init__进行初始化操作。如果在存储中找不到持久化保存的featurizer file,那么会抛出相关的异常:

        我们可以看下在__init__方法中调用的_create_featurizer方法,首先会从policy的配置文件中查找是否存在featurizer的配置信息,如果不存在,则调用方法_standard_featurizer创建一个新的featurizer:

   如果能从配置文件查到配置信息,则首先通过lookup_path的值查出可执行函数名称featurizer_func,并且获得state_featurizer_configs,然后传入函数state_featurizer_func获得featurizer_config中关于state_featurizer的配置信息,再传入函数featurizer_func从而获得featurizer:

下面这个方法用于检查调用函数时的参数是否有效:

下面这个方法是把对话状态跟踪器转换为向量(vector)的表示,即把由多个turns构成的对话状态跟踪器转换为能够被机器学习所使用的一个float vector。方法传入的参数包括:

training_trackers:对话状态跟踪器的列表

domain:Domain

precomputations:包含预先计算好的features和attributes

bilou_tagging:是否使用BILOU

返回一个元组对象Tuple,包括:

-一个字典,包含属性(INTENT, TEXT, ACTION_NAME, ACTION_TEXT,ENTITIES, SLOTS, FORM)和对应的features的列表

-每个对话turn中使用的label ids(e.g. action ids)

-一个字典,包含entity type(ENTITY_TAGS)和对应的features的列表

然后基于所支持的数据类型和最大训练数据的配置信息等进行相应的处理,最后返回上面提到的关于state_features, label_ids, entity_tags的三元组Tuple。

下面这个方法是把对话状态跟踪器转换为预测使用的各种状态,其中参数use_text_for_last_user_input表示是否需要使用text还是intent label来对最新的用户输入进行特征提取操作:

这个方法会调用TrackerFeaturizer的方法prediction_states:

Rasa提供的几种policies与Policy之间的关系如下:

-MemoizationPolicy:继承自Policy

-RulePolicy:继承自MemoizationPolicy

-TEDPolicy:继承自Policy

-UnexpecTEDIntentPolicy:继承自TEDPolicy

可以看出,这些policies的公共父类是Policy。

  2.  Policy训练源码详解

       下面的方法train用来训练policy,在Policy中,这是一个抽象方法,意味着需要它的子类来实现具体的逻辑,如果没有实现,则会报错:

      方法参数training_trackers表示来自训练数据(即包含历史对话信息的stories)中的rules和story的对话状态跟踪器。

下面是在RulePolicy中实现的train方法:

下面是Policy的子类TEDPolicy实现的train方法:

首先检查有无训练数据的tracker,因为TEDPolicy的具体功能是由模型TED来实现的,而需要训练数据(stories)来训练模型:

根据支持的数据类型获取训练用的tracker:

调用方法_prepare_for_training,传入tracker,domain,和预计算结果(包括features和attributes),获取模型数据(通过RasaModelData构造)和label ids:

判断model_data是否为空,如果不为空则调用方法run_training进行训练,完成训练后进行持久化操作,即保存到model_storage中,train方法返回resource,在推理时使用这个locator来从model_storage中加载已训练的模型:

   3. Policy预测源码详解

      下面方法用来根据对话状态跟踪器来预测对话机器人的next action,返回的是PolicyPrediction对象,具体参数如下:

tracker:包含到当前对话turn为止的历史对话信息

domain: 模型的domain配置

rule_only_data:与rules相关的slots和loops,这是可选参数

**kwargs:基于当前policy组件在graph中的依赖关系,可以使用不同的输入(指来自所依赖组件的输出)来进行预测

注意这是一个抽象方法,在Policy的各个子类中必须实现,否则会抛出错误:

下面是在TEDPolicy中实现的方法predict_action_probabilities:

从对话跟踪器获取模型数据,调用方法run_inference进行推理获得输出结果:

获取最后一次预测的similarities和confidence的计算结果:

最后调用方法_prediction进行预测:

PolicyPrediction封装了一个policy预测相关的信息,譬如

probabilities:每个action的预测概率

policy_priority:policy的优先级

events:在预测之后policy会使用哪些events来更新tracker的状态

is_end_to_end_prediction:如果是”True”,则直接使用从用户输入信息提取的features来进行预测而不是使用intent

is_no_user_prediction:如果是”True”,则预测时既不会使用从用户输入信息提取的features也不会使用intent,这可以看做是”happy loop paths”的一个样例场景(譬如使用form时)

hide_rule_turn:如果是”True”,表示由不存在于stories中的rules来进行预测

下面的方法根据给定的action,policy和confidence等信息构建出一个PolicyPrediction:

获取预测得到的分数最高的confidence:

格式化对话状态跟踪器的信息以便阅读:

从policy的配置中获取featurizer并且限定只能有一个,返回的是一个可调用的函数” featurizer_func”:

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值