NLU 模型训练源码分析
rasa/train.py 是模型训练的入口文件:其中 _train_async_internal
函数是训练 NLU 和 Core 的总入口,_train_nlu_with_validated_data
是具体训练 NLU 的函数。
# NOTE(review): excerpt from rasa/train.py; body lines are elided ("....."), not runnable as-is.
async def _train_async_internal( ### entry point that trains both Core and NLU
file_importer: TrainingDataImporter,
train_path: Text,
output_path: Text,
dry_run: bool,
force_training: bool,
fixed_model_name: Optional[Text],
persist_nlu_training_data: bool,
core_additional_arguments: Optional[Dict] = None,
nlu_additional_arguments: Optional[Dict] = None,
model_to_finetune: Optional[Text] = None,
finetuning_epoch_fraction: float = 1.0,
) -> TrainingResult:
.....
trained_model = await _train_nlu_with_validated_data( # train the NLU model
file_importer,
output=output_path,
fixed_model_name=fixed_model_name,
persist_nlu_training_data=persist_nlu_training_data,
additional_arguments=nlu_additional_arguments,
model_to_finetune=model_to_finetune,
finetuning_epoch_fraction=finetuning_epoch_fraction,
)
.....
# NOTE(review): excerpt with elided lines ("....."); shown for analysis, not runnable as-is.
async def _train_nlu_with_validated_data( ## common path for both `rasa train` and `rasa train nlu`
file_importer: TrainingDataImporter,
output: Text,
train_path: Optional[Text] = None,
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
model_to_finetune: Optional["Text"] = None,
finetuning_epoch_fraction: float = 1.0,
) -> Optional[Text]:
.....
_train_path = stack.enter_context(TempDirectoryPath(tempfile.mkdtemp()))
"""
Example of the config dict returned by file_importer.get_config():
{
'language': 'zh',
'pipeline': [{
'name': 'JiebaTokenizer'
}],
'policies': [{
'name': 'FormPolicy'
}]
}
"""
config = await file_importer.get_config() # read the pipeline config (see example above)
....
async with telemetry.track_model_training( # telemetry context manager wrapping the training run
file_importer,
model_type="nlu",
is_finetuning=model_to_finetune is not None,
): # the function that actually performs the training
await rasa.nlu.train(
config, # the config dict (see example above)
file_importer, # NluDataImporter object
_train_path, # temporary directory, e.g. /tmp/tmp79v61479
fixed_model_name="nlu",
persist_nlu_training_data=persist_nlu_training_data, # False by default
model_to_finetune=model_to_finetune,
**additional_arguments,
)
- 先通过 file_importer.get_config() 读取配置文件,
然后调用 rasa.nlu.train,它是具体执行训练的函数:
async def train(
nlu_config: Union[Text, Dict, RasaNLUModelConfig], # the config dict shown in the snippet above
data: Union[Text, "TrainingDataImporter"], # NluDataImporter object
path: Optional[Text] = None, # e.g. /tmp/tmp79v61479
fixed_model_name: Optional[Text] = None, # 'nlu'
storage: Optional[Text] = None,
component_builder: Optional[ComponentBuilder] = None,
training_data_endpoint: Optional[EndpointConfig] = None,
persist_nlu_training_data: bool = False,
model_to_finetune: Optional[Interpreter] = None,
**kwargs: Any,
) -> Tuple[Trainer, Interpreter, Optional[Text]]:
"""Loads the trainer and the data and runs the training of the model."""
from rasa.shared.importers.importer import TrainingDataImporter
if not isinstance(nlu_config, RasaNLUModelConfig):
nlu_config = config.load(nlu_config) ## config.load reads the contents of config.yml and stores them in a RasaNLUModelConfig object
# Ensure we are training a model that we can save in the end
# WARN: there is still a race condition if a model with the same name is
# trained in another subprocess
trainer = Trainer(nlu_config, component_builder, model_to_finetune=model_to_finetune) # build the Trainer object
persistor = create_persistor(storage)
if training_data_endpoint is not None:
training_data = await load_data_from_endpoint(training_data_endpoint, nlu_config.language)
elif isinstance(data, TrainingDataImporter):
training_data = await data.get_nlu_data(nlu_config.language)# load the NLU training data
else:
training_data = load_data(data, nlu_config.language) ## load_data reads the training data and stores it in a TrainingData object
"""
Structure of training_data:
training_examples: <rasa.shared.nlu.training_data.message.Message object at 0x7fc09f1eef10>
entity_synonyms: {'NYC': 'New York City', 'nyc': 'New York City'}
regex_features: [{'name': 'zipcode', 'pattern': '[0-9]{5}'}]
lookup_tables: [{'name': 'location', 'elements': ['嘉兴', '海宁', '哈尔滨', '绍兴', '嘉善']}]
Looking at the structure definition: training_examples is a list whose elements are Message objects.
Message object{
data={'text': '查询北京的天气', 'intent': 'search_weather', 'entities': [{'start': 2, 'end': 4, 'value': '北京', 'entity': 'location'}]}
features=[]
output_properties={'text'}
time = None
}
"""
training_data.print_stats()
if training_data.entity_roles_groups_used():
rasa.shared.utils.common.mark_as_experimental_feature("Entity Roles and Groups feature")
interpreter = trainer.train(training_data, **kwargs)## trainer.train runs each pipeline component's preprocessing and training functions in turn
if path: # persist the trained model
persisted_path = trainer.persist(path, persistor, fixed_model_name, persist_nlu_training_data)
else:
persisted_path = None
return trainer, interpreter, persisted_path
上述函数的关键步骤:
- trainer = Trainer(nlu_config, component_builder, model_to_finetune=model_to_finetune):构建训练器对象;
- training_data = load_data(data, nlu_config.language):读取训练数据,保存在 TrainingData 对象中;
- interpreter = trainer.train(training_data, **kwargs):进入训练函数,返回 Interpreter 对象;
- persisted_path = trainer.persist(path, persistor, fixed_model_name, persist_nlu_training_data):保存模型。
def train(self, data: TrainingData, **kwargs: Any) -> "Interpreter":
"""Trains the underlying pipeline using the provided training data."""
self.training_data = data
self.training_data.validate()
# A `context` dict is used here: before training starts, each component publishes
# whatever it provides via provide_context() into `context`. Most components have
# nothing to contribute; frameworks such as MITIE or spaCy supply environment data
# like word vectors.
context = kwargs
for component in self.pipeline:
updates = component.provide_context()
if updates:
context.update(updates)
# Before the training starts: check that all arguments are provided
if not self.skip_validation:
components.validate_required_components_from_data(
self.pipeline, self.training_data
)
# data gets modified internally during the training - hence the copy
working_data: TrainingData = copy.deepcopy(data)
for i, component in enumerate(self.pipeline):
logger.info(f"Starting to train component {component.name}")
component.prepare_partial_processing(self.pipeline[:i], context)## expose the earlier components and the pre-training context to this component
updates = component.train(working_data, self.config, **context)
logger.info("Finished training component.")
if updates:
context.update(updates)
return Interpreter(self.pipeline, context)
train 方法的关键步骤:
- context:保存训练过程中的上下文信息;
- for i, component in enumerate(self.pipeline):循环遍历 pipeline 中的组件,依次训练;
- Interpreter(self.pipeline, context):训练结束后返回 Interpreter 对象。