本文继续围绕工业级业务对话平台和框架Rasa解析以下核心源码:
- GraphNode,Rasa把DAG图中的每个组件都视作一个graph node,譬如DIETClassifier,TEDPolicy等
- ExecutionContext
- GraphNodeHook
- GraphModelConfiguration
六、定制Graph Components
5. 关于graph.py核心源码解析
1) GraphNode解析
GraphNode的作用是在DAG图中实例化和运行一个graph component。GraphNode是GraphComponent的封装(wrapper),从而可以在一个graph的上下文里执行这个graph component。 GraphNode负责在正确的时间实例化这个component,收集从这个graph node的parent nodes获取的inputs,执行这个component的run函数,然后根据依赖关系把函数执行后的output进行传递。
方法__init__用于实例化GraphNode:
参数说明如下:
- node_name:在graph schema里定义的node name
- component_class:需要被实例化和运行的component class,即接口GraphComponent的组件实现类,譬如WhitespaceTokenizer
- constructor_name:用来实例化component的方法名
- component_config:传递给component的配置
- fn_name:当一个graph node执行时,需要运行被实例化的graph component的函数的名称
- inputs:是一个map,key为input name,value为提供inputs的parent node的name,格式如[input name, parent node name]
- eager:bool类型,True:这个graph node会被立即实例化,False:仅在run之前才实例化
- model_storage:graph components需要从哪个model storage进行load以及持久化操作时把graph components保存到哪个model storage
- resource:是否根据这个resource从指定的model storage加载(load) graph components
- execution_context:关于当前graph运行环境的信息
- hooks:在一个graph node执行前和执行后被调用的代码逻辑
在成员属性”_fn”赋值时,Callable是通过graph component class和fn_name这两个参数传入方法getattr而获得:
需要确保返回的是一个可以调用的函数,Callable来自于Python的library “typing”:
方法_load_component中的参数” **kwargs: Any”表示可以传入任何的[key, value]形式的参数:
方法__call__会调用这个graph node所对应的GraphComponent的run 方法(即参数fn_name对应的方法名称):
这个方法有个参数*inputs_from_previous_nodes: Tuple[Text, Any],下面是调用graph node的__call__方法的示例(Tuple数量为2):
在__call__方法中调用GraphNode的成员方法_run_before_hooks做graph node运行前的处理:
在方法_run_before_hooks中会调用GraphNodeHook的方法on_before_node进行处理:
在方法__call__的最后调用了成员方法_run_after_hooks在这个graph node运行完成后进行处理:
2) ExecutionContext解析
ExecutionContext是一个数据封装类,主要包含以下信息:
- graph_schema:用于描述一个graph中各个graph component的信息,这里使用field(repr=False),表示输出excution context信息时不包括graph schema
- model_id:唯一标识模型的ID,如果存在多个模型(譬如有两个DIETClassifier),需要用ID来标识
- should_add_diagnostic_data:是否需要在训练时添加诊断信息
- is_finetuning:是否需要进行fine-tuning
- node_name:一个graph node的名称
3) GraphSchema解析
GraphSchema是一个数据封装类,一个graph中的所有graph nodes组成了一个dictionary,即对应graph schema,其格式可以是JSON格式或者其它的格式。
下面是把GraphSchema对象转换为JSON格式(序列化):
下面是通过resource查找并把存储的graph schema信息转换为GraphSchema对象(反序列化):
GraphSchema会使用SchemaNode,它用于表示一个graph中的一个node。
包含以下属性:
具体说明如下:
- needs:在函数”fn”或者” constructor_name”(如果eager为False,则使用constructor_name)中,哪些参数会由这个node的parent node传入
- uses:由哪个graph component决定这个graph node的执行行为
- constructor_name:用于实例化这个graph component的constructor的名称
- fn:当一个graph node执行时,需要运行被实例化的graph component的函数的名称
- config:这个graph node的用户配置信息
- eager:bool类型,True:这个graph node会在graph运行之前被实例化,False:在graph run时才实例化(懒加载)
- is_target:bool类型,True:graph component训练结果总是需要被添加到model archive以便推理时可以使用数据
- is_input:bool类型,True:表示这个graph node会运行
- resource:如果设定,表示会根据指定的resource load这个graph node
4) GraphNodeHook解析
这个类包含了在一个graph node运行前和运行后需要执行的逻辑。它定义了两个抽象方法on_before_node和on_after_node。通过这两个方法的具体实现逻辑可以对一个graph node的运行生命周期进行控制,譬如graph node运行前的初始化处理,运行后的保存及清理处理等。
具体参数说明:
- node_name:即将运行的graph node的名称
- execution_context:当前graph运行的上下文
- config:这个graph node的配置信息
- received_inputs:参数名和参数值的mapping
具体参数说明:
- node_name:运行完成的graph node的名称
- execution_context:当前graph运行的上下文
- config:这个graph node的配置信息
- output:这个graph node运行完成后的输出信息
- input_hook_data:从” on_before_node”方法返回的数据
5) GraphModelConfiguration解析
这个类封装了graph运行时需要使用的模型配置信息。
6) 关于graph node的测试
Rasa提供了相关的测试源码,以下是关于graph node测试的一个方法: