NLU Model Training Source Code Analysis

rasa/train.py is the model-training entry file: _train_async_internal is the entry point for training both NLU and Core, and _train_nlu_with_validated_data is the function that trains the NLU model.

async def _train_async_internal(  # entry point for training both NLU and Core
    file_importer: TrainingDataImporter,
    train_path: Text,
    output_path: Text,
    dry_run: bool,
    force_training: bool,
    fixed_model_name: Optional[Text],
    persist_nlu_training_data: bool,
    core_additional_arguments: Optional[Dict] = None,
    nlu_additional_arguments: Optional[Dict] = None,
    model_to_finetune: Optional[Text] = None,
    finetuning_epoch_fraction: float = 1.0,
) -> TrainingResult:
     .....
     trained_model = await _train_nlu_with_validated_data( # train nlu
            file_importer,
            output=output_path,
            fixed_model_name=fixed_model_name,
            persist_nlu_training_data=persist_nlu_training_data,
            additional_arguments=nlu_additional_arguments,
            model_to_finetune=model_to_finetune,
            finetuning_epoch_fraction=finetuning_epoch_fraction,
        )
     .....
async def _train_nlu_with_validated_data(  # both 'rasa train' and 'rasa train nlu' pass through here
    file_importer: TrainingDataImporter,
    output: Text,
    train_path: Optional[Text] = None,
    fixed_model_name: Optional[Text] = None,
    persist_nlu_training_data: bool = False,
    additional_arguments: Optional[Dict] = None,
    model_to_finetune: Optional["Text"] = None,
    finetuning_epoch_fraction: float = 1.0,
) -> Optional[Text]:
    
    .....
        _train_path = stack.enter_context(TempDirectoryPath(tempfile.mkdtemp()))
        """
            {
            'language': 'zh',
            'pipeline': [{
                'name': 'JiebaTokenizer'
            }],
            'policies': [{
                'name': 'FormPolicy'
            }]
            }
        """
        config = await file_importer.get_config() #
        ....

        async with telemetry.track_model_training(  # telemetry context manager that tracks this training run
            file_importer,
            model_type="nlu",
            is_finetuning=model_to_finetune is not None,
        ):  # the actual training happens inside this block
            await rasa.nlu.train(
                config,  # see the example config above
                file_importer, # NluDataImporter object
                _train_path,  # temporary directory, e.g. /tmp/tmp79v61479
                fixed_model_name="nlu",
                persist_nlu_training_data=persist_nlu_training_data,  # False
                model_to_finetune=model_to_finetune,
                **additional_arguments,
            )
  • First, the config file is read via file_importer.get_config().
  • rasa.nlu.train is the function that performs the actual training; a minimal sketch of how a caller reaches it follows, and its full source is below.
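Before reading rasa.nlu.train itself, here is a minimal sketch of the public entry that funnels into this code path. This is a sketch assuming the Rasa 2.x API, where rasa.train exposes train_nlu; the file paths are hypothetical.

from rasa.train import train_nlu

# `rasa train nlu` on the CLI lands here, which in turn calls
# _train_nlu_with_validated_data and finally rasa.nlu.train.
model_path = train_nlu(
    config="config.yml",      # pipeline config, e.g. the JiebaTokenizer pipeline above
    nlu_data="data/nlu.yml",  # NLU training examples
    output="models/",         # where the packaged model archive is written
    fixed_model_name="nlu",
)
print(model_path)  # e.g. models/nlu.tar.gz

The full rasa.nlu.train source follows: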
async def train(
    nlu_config: Union[Text, Dict, RasaNLUModelConfig],  # the config from the snippet above
    data: Union[Text, "TrainingDataImporter"], # NluDataImporter object
    path: Optional[Text] = None,  # /tmp/tmp79v61479
    fixed_model_name: Optional[Text] = None,   # 'nlu'
    storage: Optional[Text] = None,
    component_builder: Optional[ComponentBuilder] = None,
    training_data_endpoint: Optional[EndpointConfig] = None,
    persist_nlu_training_data: bool = False,
    model_to_finetune: Optional[Interpreter] = None,
    **kwargs: Any,
) -> Tuple[Trainer, Interpreter, Optional[Text]]:
    """Loads the trainer and the data and runs the training of the model."""
    from rasa.shared.importers.importer import TrainingDataImporter
    if not isinstance(nlu_config, RasaNLUModelConfig):
        nlu_config = config.load(nlu_config)  # config.load reads config.yml and wraps its contents in a RasaNLUModelConfig object
    # Ensure we are training a model that we can save in the end
    # WARN: there is still a race condition if a model with the same name is
    # trained in another subprocess
    trainer = Trainer(nlu_config, component_builder, model_to_finetune=model_to_finetune)  # build the Trainer object
    persistor = create_persistor(storage)
    if training_data_endpoint is not None:
        training_data = await load_data_from_endpoint(training_data_endpoint, nlu_config.language)
    elif isinstance(data, TrainingDataImporter):
        training_data = await data.get_nlu_data(nlu_config.language)  # load the NLU data
    else:
        training_data = load_data(data, nlu_config.language)  # load_data reads the training data and stores it in a TrainingData object
        """
        Structure of training_data:
        training_examples: <rasa.shared.nlu.training_data.message.Message object at 0x7fc09f1eef10>
        entity_synonyms: {'NYC': 'New York City', 'nyc': 'New York City'}
        regex_features: [{'name': 'zipcode', 'pattern': '[0-9]{5}'}]
        lookup_tables: [{'name': 'location', 'elements': ['嘉兴', '海宁', '哈尔滨', '绍兴', '嘉善']}]

        Looking at the definition of the data structure: training_examples is a list whose elements are Message objects.
        Message object {
          data={'text': '查询北京的天气', 'intent': 'search_weather', 'entities': [{'start': 2, 'end': 4, 'value': '北京', 'entity': 'location'}]}
          features=[]
          output_properties={'text'}
          time=None
        }
        """
    training_data.print_stats()
    if training_data.entity_roles_groups_used():
        rasa.shared.utils.common.mark_as_experimental_feature("Entity Roles and Groups feature")
    interpreter = trainer.train(training_data, **kwargs)  # trainer.train calls each pipeline component's preprocessing and training functions in turn
    if path:  # persist the model
        persisted_path = trainer.persist(path, persistor, fixed_model_name, persist_nlu_training_data)
    else:
        persisted_path = None
    return trainer, interpreter, persisted_path
  • trainer = Trainer(nlu_config, component_builder, model_to_finetune=model_to_finetune) builds the training object.
  • training_data = load_data(data, nlu_config.language) reads the data into a TrainingData object (see the sketches after this list).
  • interpreter = trainer.train(training_data, **kwargs) enters the training function and returns an Interpreter object.
  • persisted_path = trainer.persist(path, persistor, fixed_model_name, persist_nlu_training_data) saves the model.
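The config.load call above turns a path to config.yml (or a raw dict) into a RasaNLUModelConfig. A quick sketch, assuming the Rasa 2.x API; the attribute names are taken from that version and should be treated as assumptions:

from rasa.nlu import config

nlu_config = config.load("config.yml")  # also accepts a dict like the example above
print(nlu_config.language)              # 'zh'
print(nlu_config.component_names)       # ['JiebaTokenizer']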
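To make the TrainingData structure described in the docstring concrete, here is a small sketch that builds the same kind of object by hand; it assumes the Rasa 2.x module layout, and the values mirror the docstring above.

from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData

# One training example; data holds the text, intent, and entity annotations.
msg = Message(data={
    'text': '查询北京的天气',
    'intent': 'search_weather',
    'entities': [{'start': 2, 'end': 4, 'value': '北京', 'entity': 'location'}],
})

training_data = TrainingData(
    training_examples=[msg],
    entity_synonyms={'NYC': 'New York City', 'nyc': 'New York City'},
    regex_features=[{'name': 'zipcode', 'pattern': '[0-9]{5}'}],
    lookup_tables=[{'name': 'location', 'elements': ['嘉兴', '海宁', '哈尔滨']}],
)
training_data.print_stats()  # the same stats call used in rasa.nlu.train above

Back to the pipeline: the Trainer.train method that does the per-component work looks like this: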
    def train(self, data: TrainingData, **kwargs: Any) -> "Interpreter":
        """Trains the underlying pipeline using the provided training data."""
        self.training_data = data
        self.training_data.validate()
        # A context dict is used here: before training starts, each component
        # publishes any context it provides via provide_context(). Most components
        # provide nothing; framework-backed ones such as MITIE and spaCy supply
        # environment objects such as loaded word vectors.
        context = kwargs
        for component in self.pipeline:
            updates = component.provide_context()
            if updates:
                context.update(updates)
        # Before the training starts: check that all required arguments are provided
        if not self.skip_validation:
            components.validate_required_components_from_data(
                self.pipeline, self.training_data
            )
        # data gets modified internally during the training - hence the copy
        working_data: TrainingData = copy.deepcopy(data)
        for i, component in enumerate(self.pipeline):
            logger.info(f"Starting to train component {component.name}")
            component.prepare_partial_processing(self.pipeline[:i], context)  # give the component the upstream pipeline and the pre-training context
            updates = component.train(working_data, self.config, **context)
            logger.info("Finished training component.")
            if updates:
                context.update(updates)
        return Interpreter(self.pipeline, context)
  • context holds the training context.
  • for i, component in enumerate(self.pipeline): iterates over the pipeline components and trains each one in order.
  • Interpreter(self.pipeline, context) returns the Interpreter object. A skeleton component showing the contract this loop relies on follows.
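The loop above only relies on each component implementing provide_context and train. A minimal sketch of that contract, assuming the Rasa 2.x Component base class; the component name and training logic are hypothetical.

from typing import Any, Dict, Optional, Text

from rasa.nlu.components import Component
from rasa.nlu.config import RasaNLUModelConfig
from rasa.shared.nlu.training_data.training_data import TrainingData

class MyComponent(Component):
    def provide_context(self) -> Optional[Dict[Text, Any]]:
        # Called before training starts; most components return nothing.
        # MITIE/spaCy components return their loaded framework objects here.
        return None

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        # Receives the (deep-copied) training data plus the accumulated
        # context via **kwargs; may modify training_data in place for
        # downstream components.
        for example in training_data.training_examples:
            pass  # fit the component on each Message here

After training, the returned Interpreter can be used directly, e.g. interpreter.parse('查询北京的天气') yields the intent and entities predicted by the trained pipeline.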