四、Train —— 准备工作
当我们执行 rasa train
命令后,实际会进入到函数 rasa.cli.train.train()
中,这相当于模型训练的主函数,接下来,我们将对该过程进行拆解,看看 rasa train
背后,都发生了什么。
# rasa.cli.train.train()
def train():
domain = rasa.cli.utils.get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
)
# 'rasa-2.2.1/rasa-demo/domain.yml'
config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS)
# 'rasa-2.2.1/rasa-demo/config.yml'
training_files = [
rasa.cli.utils.get_validated_path(
f, "data", DEFAULT_DATA_PATH, none_is_valid=True
)
for f in args.data
] # ['rasa-2.2.1/rasa-demo/data']
在训练前需要检测有效的 domain.yml
,config.yml
以及 NLU 和 Core 模型训练的数据文件位置等。
对于 domain 文件,一般都是指默认的 domain.yml
文件;
对于 config 文件,通常也是指默认的 config.yml
文件,对于合法的 config 文件,其必须要配置的 key 有: [“language”],如果发现在 config 文件中缺失该字段,则会报异常说明。“language” 指的是你训练语料面向的语种,对于英文用 “en” 表示,中文用 “zh” 表示,更多语种选择可以参照语言支持。
对于数据文件,通常也是指默认的 “data/” 目录,这里通常会存放有关 NLU 和 Core 训练相关的所有数据文件。
# rasa.train.train()
def train(
domain: Text, # 'rasa-2.2.1/rasa-demo/domain.yml'
config: Text, # 'rasa-2.2.1/rasa-demo/config.yml'
training_files: Union[Text, List[Text]], # ['rasa-2.2.1/rasa-demo/data']
output: Text = DEFAULT_MODELS_PATH, # 'models'
dry_run: bool = False, # False
force_training: bool = False, # False
fixed_model_name: Optional[Text] = None, # None
persist_nlu_training_data: bool = False, # False
core_additional_arguments: Optional[Dict] = None, # {'augmentation_factor': 50, 'debug_plots': False}
nlu_additional_arguments: Optional[Dict] = None, # {'num_threads': 1}
loop: Optional[asyncio.AbstractEventLoop] = None, # None
model_to_finetune: Optional[Text] = None, # None
finetuning_epoch_fraction: float = 1.0, # 1.0
) -> TrainingResult:
""" 在 async 循环中运行 Rasa Core 和 NLU 训练。
Args:
domain: domain 文件的位置
config: Core 和 NLU 配置文件的位置
training_files: Core 和 NLU 训练数据的位置
output: 模型输出的位置
dry_run: 如果设置为 `True`,则不进行训练,并打印是否需要进行训练的有关信息
force_training: 如果设置为 `True` 即使数据未改变也会强制重新训练模型
fixed_model_name: 模型存储用的名字
persist_nlu_training_data: 如果设置为 `True`,则 NLU 训练数据应与模型一起持久化存储。
core_additional_arguments: Core 部分的额外训练参数。
nlu_additional_arguments: 附加的训练参数将传递给每个 NLU 组件中。
loop: 运行协同程序的可选 EventLoop。
model_to_finetune: 模型的可选路径,该路径应进行微调,或者在使用最新训练模型的情况下提供一个目录。
finetuning_epoch_fraction: 当前在模型配置中指定的用于微调的训练时段的分数。
Returns:
An instance of `TrainingResult`.
"""
这里的代码涉及到 “async” 异步操作,如果对此感到陌生,可以暂且忽略,暂时先把它们都看做普通函数看待即可。
此时,进入到函数 rasa.train.train_async()
中,首先做的第一件事就是获取文件加载器
file_importer = TrainingDataImporter.load_from_config(
config, domain, training_files
)
通过函数调用名可以看到,数据加载器的读取是从配置文件中读取的,具体而言,可以在 config.yml
文件下通过 importers
键进行配置,可以配置自定义的数据加载器,但通常而言,我们是不进行此项配置的,即,使用的是默认的 “RasaFileImporter”,另外还有 “MultiProjectImporter”,不过目前还处于试验阶段。
class RasaFileImporter(TrainingDataImporter): # 默认的训练文件加载器
"""Default `TrainingFileImporter` implementation."""
def __init__(
self,
config_file: Optional[Text] = None, # 'rasa-2.2.1/rasa-demo/config.yml'
domain_path: Optional[Text] = None, # 'rasa-2.2.1/rasa-demo/domain.yml'
training_data_paths: Optional[Union[List[Text], Text]] = None, # [rasa-2.2.1/rasa-demo/data']
training_type: Optional[TrainingType] = TrainingType.BOTH, # NLU 与 core 都参与训练
):
self._domain_path = domain_path # 'rasa-2.2.1/rasa-demo/domain.yml'
self._nlu_files = rasa.shared.data.get_data_files(
training_data_paths, rasa.shared.data.is_nlu_file # data/,符合 nlu 条件的文件
)
self._story_files = rasa.shared.data.get_data_files(
training_data_paths, rasa.shared.data.is_story_file # data/,符合 story 条件的文件
)
self.config = autoconfig.get_configuration(config_file, training_type) # 读取 config,并对缺失部分自动配置
在这里,需要明确 domain 文件的位置,并根据 Core 还是 NLU 部分将训练数据分开,此外对于配置文件中缺失的配置项进行自动配置。具体实现上,首先会遍历 “data/” 目录下的所有文件,读取文件内容后,猜测其可能的编写格式(如 Markdown 方式或者是 YAML 等),如果一个文件中包含 “nlu” 或 “responses” 键,则认为是 NLU 部分的训练数据文件,详见 rasa.shared.nlu.training_data.formats.rasa_yaml.RasaYAMLReader.is_yaml_nlu_file()
;同理,如果一个文件下包含 “stories” 或 “rules” 键,则认为是 story 文件,详见 rasa.shared.core.training_data.story_reader.yaml_story_reader.YAMLStoryReader.is_stories_file()
。
因此,nlu 文件共有 8 个:
['data/nlu/chitchat.yml', 'data/nlu/faq.yml', 'data/nlu/general.yml', 'data/nlu/lookups/location.yml', 'data/nlu/lookups/products.yml', 'data/nlu/nlu.yml', 'data/nlu/out_of_scope.yml', 'data/nlu/responses/responses.yml']
core 文件共有 15 个:
['data/rules/fallback.yml', 'data/rules/feedback.yml', 'data/rules/forms.yml', 'data/rules/rules.yml', 'data/stories/canthelp.yml', 'data/stories/chitchat.yml', 'data/stories/closetheloop.yml', 'data/stories/faqs.yml', 'data/stories/handoff.yml', 'data/stories/oos.yml', 'data/stories/step1_get_started.yml', 'data/stories/step2_rasa_init.yml', 'data/stories/step3_rasa_x.yml', 'data/stories/step4_community.yml', 'data/stories/stories.yml']
在配置文件中,共需要配置 3 个键,分别是 “policies”、“language” 以及 “pipeline”,除了 “language” 必须要配置外,另外两个都可以不配置,而采用默认的配置内容。当用户提供这些配置则采用用户配置的内容,否则需要从默认配置文件中复制相应的配置项内容到 config.yml
中,自动配置项位于文件 rasa.shared.importers.default_config.yml
中。
重新回到上面的 load_from_config()-->load_from_dict()
中,最后一行:
return E2EImporter(ResponsesSyncImporter(CombinedDataImporter(importers)))
即,最后使用的文件加载器是在 “RasaFileImporter” 的基础上又嵌套初始化了 3 种不同的加载器类:
-
CombinedDataImporter:使用多个 “TrainingDataImporter” 实例加载数据,就像它们是单个实例一样。实际上就是说当你配置了多个文件加载器,例如,当你配置了 A 文件加载器可以加载处理 “a”、“b”、“c” 文件,还配置了 B 文件加载器可以加载 “x”、“y“、“z” 文件,则 ”CombinedDataImporter“ 的目的就是透明化 A 与 B 的存在,看起来像是一个文件加载器可以同时加载 ”abcxyz“ 六个文件。
async def get_config(self) -> Dict: configs = [importer.get_config() for importer in self._importers] configs = await asyncio.gather(*configs) return reduce(lambda merged, other: {**merged, **(other or {})}, configs, {})
仅以加载 config 的方法可以看到,其他几个方法完全一致,都是三部曲:遍历配置的每种加载器,分别进行加载,通过
asyncio.gather
将结果汇总,最后通过 merged 进行合并。 -
ResponsesSyncImporter:在 NLU 训练数据和 domain 之间同步 responses 模板。
-
E2EImporter:使用故事中的 action 或用户消息增强 NLU 训练数据以及将来自故事的潜在端到端bot消息作为操作添加到域中。
五、Train —— domain 解析
# rasa/train.py
async def train_async():
with TempDirectoryPath(tempfile.mkdtemp()) as train_path:
domain = await file_importer.get_domain()
...
在这里,利用 with 上下文管理器创建了一个临时工作目录,用于存储模型训练过程中的一些文件,关于 ”with“ 语法的基本使用,这里稍微提一下,当 with 语句后面的内容(tempfile.mkdtemp()
)被求值后,就会进入 TempDirectoryPath()
类下的 __enter__()
方法,将结果返回给 as
后的变量,即 train_path
,当 with 语句块完成后,再次进入到 __exit__()
方法中,一般进行退出句柄等操作。with
的最常见用法是用于文件 IO。
接下来,我们进入到这一小节的重点,即 get_domain
,对 domain 文件的内容进行解析。
缓存
按照上面的讲解,此时,会进入到 E2eImporter.get_domain()
中,但是需要注意的是,E2eImporter.get_domain()
被 rasa.shared.utils.common.cached_method()
所装饰,关于装饰器的内容,此处不做过多展开,这里装饰器的目的是对该函数的结果进行缓存,以便后续再执行该函数时,可以直接从缓存中加载,而不必再次执行该函数。
具体的缓存过程其实非常简单,记录缓存的类对象,要缓存的方法以及该方法的参数,参见下面代码:
class Cache:
"""Helper class to abstract the caching details."""
def __init__(self, caching_object: object, args: Any, kwargs: Any) -> None:
self.caching_object = caching_object
self.cache = getattr(caching_object, self._cache_name(), {})
# noinspection PyUnresolvedReferences
self.cache_key = functools._make_key(args, kwargs, typed=False)
这里的 caching_object
是指缓存的类对象,例如 E2EImporter
类对象,而 cache
指的是要缓存该类对象下的哪个方法值,例如 get_domain()
方法。这里的 cache_name()
实际上就是类名称与类方法名称拼接起来的,如 “_cached_E2EImporter_get_domain”,而 cache_key
是由参数构建的一个 key,毕竟不同参数得到的函数结果值是不一致的,所以要以参数的组合作为唯一的 key 进行结果缓存。
数据的加载
在进入到 E2eImporter.get_domain()
后,由于这里采用的是 async/await 协程写法,所以在具体的执行上有点类似多进程的同时运行多个内容。
E2eImporter
: A
get_domain()
: A_1get_stories()
: A_2_get_domain_with_e2e_actions()
: A_3
ResponsesSyncImporter
: B
get_domain()
: B_1get_stories()
: B_2
CombinedDataImporter
: C
get_domain()
: C_1get_stories()
: C_2
RasaFileImporter
: D
get_domain()
: C_1get_stories()
: D_2
A_1 --> A_3 --> A_2 --> B_2 --> C_2 --> D_2 --> D_1
A_1 -----------------------> B_1 --> C_1 -------------> D_1
我们最终将焦点放在最后一级 RasaFileImporter.get_domain()
上,
async def get_domain(self) -> Domain:
domain = Domain.empty()
if not self._domain_path:
return domain
try:
domain = Domain.load(self._domain_path)
...
一开始首先会初始化一个空的 domain,这个初始化过程我们可以忽略掉,因为具体的逻辑在后续实参解析过程中都会涉及到,且初始化过程中大部分内容均为空。我们重点从 load()
函数开始,这里的 _domain_path
就是我们自定义的 domain.yml
文件。
def load():
domain = Domain.empty()
for path in paths:
other = cls.from_path(path)
domain = domain.merge(other)
由于 domain 文件可以分散在多个文件中,因此这里采用遍历方式读取所有可能的文件,并对所有子部分进行合并(但是我们通常都是采用单文件形式,所以合并的仅仅是初始化的空域)。
具体的加载过程是,首先遍历可能的文件夹或文件(from_path
、from_file
),读取文件内容后 from_yaml
,对 yaml 格式进行验证后,数据转为 dict 格式,再次 from_dict
def from_dict(cls, data: Dict) -> "Domain":
utter_templates = data.get(KEY_RESPONSES, {})
slots = cls.collect_slots(data.get(KEY_SLOTS, {}))
additional_arguments = data.get("config", {})
session_config = cls._get_session_config(data.get(SESSION_CONFIG_KEY, {}))
intents = data.get(KEY_INTENTS, {})
forms = data.get(KEY_FORMS, {})
_validate_slot_mappings(forms)
return cls(...)
在这里,将 domain.yml
配置文件中的各配置项进行拆解,
-
utter_templates
: dict,即 responses 部分的配置内容,这是一种较为简单的 action 应答话术,139 项 -
slots
:dict,即 slots 部分的配置内容,槽位名称,槽位类型等,25 项,需要注意的是,这里的cls.collect_slots()
方法的目的就是根据配置的槽位类型找到具体的槽位子类进行槽位初始化。def collect_slots(slot_dict: Dict[Text, Any]) -> List[Slot]: slots = [] slot_dict = copy.deepcopy(slot_dict) for slot_name in slot_dict: slot_type = slot_dict[slot_name].pop("type", None) slot_class = Slot.resolve_by_type(slot_type) # 根据配置的槽位类型找到指定的子类 slot = slot_class(slot_name, **slot_dict[slot_name]) slots.append(slot) return slots
上述函数中,最为重要的方法是
Slot.resolve_by_type(slot_type)
,它可以根据你配置文件中配置的槽位类型找到处理该类型槽位的具体槽位子类,例如,槽位 “budget” 的槽位类型是 “any”,则这里的slot_class
指的便是<class 'rasa.shared.core.slots.AnySlot'>
,接下来对其进行初始化。class Slot: type_name = None def __init__( self, name: Text, initial_value: Any = None, value_reset_delay: Optional[int] = None, auto_fill: bool = True, influence_conversation: bool = True, ) -> None: self.name = name self.value = initial_value self.initial_value = initial_value self._value_reset_delay = value_reset_delay self.auto_fill = auto_fill self.influence_conversation = influence_conversation
type_name
:它在__init__
之外,即在导入该文件时,便会执行该语句,这也是为什么程序能根据它找到具体类型的子类name
:槽位名称initial_value
:槽位的初始值,槽位在 rasa 中充当的是全局变量的作用,所以可以初始化的value_reset_delay
:经过多少轮会话后,该槽位的值会被强制性恢复到初始值,该功能在该版本下暂未实现,可能会作为 rasa 未来版本的一个新功能auto_fill
:当有一个实体名称与该槽位名称一致,会用实体的值自动填充该槽位influence_conversation
:是否会被特征化处理,从而影响会话策略在 action 上的预测,“any” 类型不能被特征化处理
-
additional_arguments
:读取的是 config 部分的配置内容,通常在domain.yml
下不会进行此项配置, {} -
session_config
:会话配置项,会话到期时间设置为 8 个小时,槽位值会被带到新会话中等 -
intents
:list,“intents” 部分的配置内容,所有的意图命令列表,40 个 -
forms
:dict,“forms” 部分的表单配置,3个 -
_validate_slot_mappings(forms)
,针对表单,验证槽位映射配置的是否准确,主要是语法层面的检查- 如果配置的是
from_entity
,下面必须要配置的 key 是 “entity” - 如果配置的是
from_intent
或from_trigger_intent
,则必要要配置的 key 是 “value” - 如果配置的是
from_text
,则没有必须要配置的 key
- 如果配置的是
-
最后将上述各个部分为了参数进行
Domain()
类初始化
初始化
实体属性处理
实体是可以有一些属性的,例如,实体 roles、groups 的概念,一般情况我们不会配置这两项内容,如果配置了,这里解析后的 roles 以及 groups 都是 dict,即 {“实体名”: “角色名”} 和 {“实体名”: “实体组名”};
意图属性处理
意图在配置的时候也可以指定属性,如 use_entities
,指定该意图需要使用的实体有哪些,还有 ignore_entities
,指定需要忽略的实体,然而 use_entities
仅仅是用户在配置文件中使用的 Key,在 rasa 内部实际使用的属性却是 used_entities
,所以这里存在一个转化过程。另外 rasa 内部实际也有很多的默认意图,例如 “restart”,也会在这个阶段进行处理。同时需要保证不存在重复意图。
- 如果用户没有设置
use_entities
,则默认是所有的实体均需要使用,则将所有的实体都作为included_entities
,以及将实体与其对应的角色或实体组进行拼接,作为新的实体,也加入到included_entities
中; 拼接符号是 “#” - 如果用户明确指定了要包含的实体,则遍历
use_entities
,将里面的实体与该实体对应的角色,实体组拼接后加入到included_entities
中;拼接符号是 “#” excluded_entities
也是上述处理逻辑,只不过针对的是ignore_entities
,即应该被忽略的那部分- 那么
included_entities
-excluded_entities
就是 rasa 内部使用的used_entities
关于这部分的处理逻辑,更多内容详见 _transform_intent_properties_for_internal_use()
函数。整体下来,self.intent_properties
包含 43 项,由于默认意图可以被覆写,所以部分默认意图前面已经包含在 intents
列表中了。
处理被覆写的默认意图
rasa 默认意图共有 5 个,“restart”、“back”、“out_of_scope”、“session_start”、“nlu_fallback”,通过两个集合求交集的方式便可轻松得知被覆写的意图有 “out_of_scope” 和 “restart”。
初始化表单
返回表单名称列表、表单详情内容、被覆写的表单 action(一般都是空列表)
动作合并
合并 actions 部分与 responses 部分的内容,均作为 sef.user_actions
,共 24+139=163 项,然后继续与 rasa 默认的 actions 进行合并,在具体的合并逻辑中,可能有点反直觉,一般我们可能会认为,当一个默认意图被覆写后应当从默认意图列表中将其删除,然后将剩下的默认意图进行拼接。实际中,为了考虑一些默认意图的实际列表位置,我们是将重复项从用户的 actions 列表中将其删除,然后将剩下的用户 actions 列表拼接在默认 actions 列表后,具体实现见 _combine_user_with_default_actions()
函数。其中,默认的 actions 共有 11 项,见 rasa.shared.core.constants.DEFAULT_ACTION_NAMES,
至此共有 172 项。接下来还需要进行合并表单名称列表(3 项),以及可能存在的 “e2e_actions”(一般均不存在该项),所以整理下来,全部的 actions 组合完毕后,self.action_names_or_texts
共有 175 项。
添加默认槽位
主要包含 3 个部分:
- 添加需要请求的槽位:requested_slot,
TextSlot(requested_slot, influence_conversation=False)
- 添加基于知识的槽位:现阶段还是实验特性,不包含该部分内容
- 对分类槽位添加默认值:默认值为
__other__
最后还会对 domain 再一次进行检查,确保 domain 配置是正确的,不含任何重复的槽位、意图、动作或实体等,以及意图-动作映射的名称不正确或话语模板丢失等情况。详见 _check_domain_sanity()
函数内容。
在 domain 解析的最后,如果 domain 分散在多个文件中,需要对各个文件内容解析后进行 merge
,主要涉及的内容就是列表或字典数据的合并,我们不再赘述,详见 merge()
函数内容即可。至此关于 “domain” 的解析工作基本完成。
接下来,我们会从上面 “数据的加载”小结中的 “D_1” 回退到 “D_2” 中,即 RasaFileImporter.get_stories()