Gavin老师Transformer直播课感悟 - 基于Transformer的Rasa Internals解密之框架核心graph.py源码完整解析及测试(十)

        本文继续围绕工业级业务对话平台和框架Rasa解析以下核心源码:

  1. GraphNode,Rasa把DAG图中的每个组件都视作一个graph node,譬如DIETClassifier,TEDPolicy等
  2. ExecutionContext
  3. GraphNodeHook
  4. GraphModelConfiguration

六、定制Graph Components

5. 关于graph.py核心源码解析

1) GraphNode解析

   GraphNode的作用是在DAG图中实例化和运行一个graph component。GraphNode是GraphComponent的封装(wrapper),从而可以在一个graph的上下文里执行这个graph component。 GraphNode负责在正确的时间实例化这个component,收集从这个graph node的parent nodes获取的inputs,执行这个component的run函数,然后根据依赖关系把函数执行后的output进行传递。

方法__init__用于实例化GraphNode:

参数说明如下:

  1. node_name:在graph schema里定义的node name
  2. component_class:需要被实例化和运行的component class,即接口GraphComponent的组件实现类,譬如WhitespaceTokenizer
  3. constructor_name:用来实例化component的方法名
  4. component_config:传递给component的配置
  5. fn_name:当一个graph node执行时,需要运行被实例化的graph component的函数的名称
  6. inputs:是一个map,key为input name,value为提供inputs的parent node的name,格式如[input name, parent node name]
  7. eager:bool类型,True:这个graph node会被立即实例化,False:仅在run之前才实例化
  8. model_storage:graph components需要从哪个model storage进行load以及持久化操作时把graph components保存到哪个model storage
  9. resource:是否根据这个resource从指定的model storage加载(load) graph components
  10. execution_context:关于当前graph运行环境的信息
  11. hooks:在一个graph node执行前和执行后被调用的代码逻辑

在成员属性”_fn”赋值时,Callable是通过graph component class和fn_name这两个参数传入方法getattr而获得:

需要确保返回的是一个可以调用的函数,Callable来自于Python的library “typing”:

方法_load_component中的参数” **kwargs: Any”表示可以传入任何的[key, value]形式的参数:

方法__call__会调用这个graph node所对应的GraphComponent的run 方法(即参数fn_name对应的方法名称):

这个方法有个参数*inputs_from_previous_nodes: Tuple[Text, Any],下面是调用graph node的__call__方法的示例(Tuple数量为2):

在__call__方法中调用GraphNode的成员方法_run_before_hooks做graph node运行前的处理:

在方法_run_before_hooks中会调用GraphNodeHook的方法on_before_node进行处理:

在方法__call__的最后调用了成员方法_run_after_hooks在这个graph node运行完成后进行处理:

2) ExecutionContext解析

    ExecutionContext是一个数据封装类,主要包含以下信息:

  1. graph_schema:用于描述一个graph中各个graph component的信息,这里使用field(repr=False),表示输出excution context信息时不包括graph schema
  2. model_id:唯一标识模型的ID,如果存在多个模型(譬如有两个DIETClassifier),需要用ID来标识
  3. should_add_diagnostic_data:是否需要在训练时添加诊断信息
  4. is_finetuning:是否需要进行fine-tuning
  5. node_name:一个graph node的名称

3) GraphSchema解析

   GraphSchema是一个数据封装类,一个graph中的所有graph nodes组成了一个dictionary,即对应graph schema,其格式可以是JSON格式或者其它的格式。

下面是把GraphSchema对象转换为JSON格式(序列化):

下面是通过resource查找并把存储的graph schema信息转换为GraphSchema对象(反序列化):

GraphSchema会使用SchemaNode,它用于表示一个graph中的一个node。

包含以下属性:

具体说明如下:

  1. needs:在函数”fn”或者” constructor_name”(如果eager为False,则使用constructor_name)中,哪些参数会由这个node的parent node传入
  2. uses:由哪个graph component决定这个graph node的执行行为
  3. constructor_name:用于实例化这个graph component的constructor的名称
  4. fn:当一个graph node执行时,需要运行被实例化的graph component的函数的名称
  5. config:这个graph node的用户配置信息
  6. eager:bool类型,True:这个graph node会在graph运行之前被实例化,False:在graph run时才实例化(懒加载)
  7. is_target:bool类型,True:graph component训练结果总是需要被添加到model archive以便推理时可以使用数据
  8. is_input:bool类型,True:表示这个graph node会运行
  9. resource:如果设定,表示会根据指定的resource load这个graph node

4) GraphNodeHook解析

        这个类包含了在一个graph node运行前和运行后需要执行的逻辑。它定义了两个抽象方法on_before_node和on_after_node。通过这两个方法的具体实现逻辑可以对一个graph node的运行生命周期进行控制,譬如graph node运行前的初始化处理,运行后的保存及清理处理等。

具体参数说明:

  1. node_name:即将运行的graph node的名称
  2. execution_context:当前graph运行的上下文
  3. config:这个graph node的配置信息
  4. received_inputs:参数名和参数值的mapping

具体参数说明:

  1. node_name:运行完成的graph node的名称
  2. execution_context:当前graph运行的上下文
  3. config:这个graph node的配置信息
  4. output:这个graph node运行完成后的输出信息
  5. input_hook_data:从” on_before_node”方法返回的数据

5) GraphModelConfiguration解析

        这个类封装了graph运行时需要使用的模型配置信息。

6) 关于graph node的测试

        Rasa提供了相关的测试源码,以下是关于graph node测试的一个方法:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值