Rasa实体抽取和意图分类之DIETClassifier
源码位置：`rasa/nlu/classifiers/diet_classifier.py`
DIETClassifier 训练所用的模型是 `diet_classifier.py` 文件中的类 `DIET`，`DIET` 继承自类 `TransformerRasaModel`。
训练(train)
- 函数 `preprocess_train_data` 负责构建训练数据，这些数据主要来自之前组件生成的特征（feature），包含的属性有 `text`、`entity` 和 `intent`。
class DIETClassifier(IntentClassifier, EntityExtractor):
    """Combined intent classifier and entity extractor (excerpt: ``train`` only)."""

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        """Train the embedding intent classifier on a data set.

        Converts ``training_data`` into model data, validates it, and then
        fits ``self.model``. Training is skipped (with a log message) when
        no data is produced or when intent classification is enabled but
        ``_check_enough_labels`` rejects the data.

        Args:
            training_data: The NLU training data to learn from.
            config: Accepted for interface compatibility; not read here.
            **kwargs: Accepted for interface compatibility; not read here.
        """
        model_data = self.preprocess_train_data(training_data)

        # Guard: there is nothing to learn from.
        if model_data.is_empty():
            logger.debug(
                f"Cannot train '{self.__class__.__name__}'. No data was provided. "
                f"Skipping training of the classifier."
            )
            return

        # Guard: intent classification needs enough distinct labels
        # (the log message below states the requirement as 2 classes).
        intents_enabled = self.component_config.get(INTENT_CLASSIFICATION)
        if intents_enabled and not self._check_enough_labels(model_data):
            logger.error(
                f"Cannot train '{self.__class__.__name__}'. "
                f"Need at least 2 different intent classes. "
                f"Skipping training of classifier."
            )
            return

        if self.component_config.get(ENTITY_RECOGNITION):
            self.check_correct_entity_annotations(training_data)

        # Keep one example around for persisting and loading the model.
        self._data_example = model_data.first_data_example()

        if not self.finetune_mode:
            # No pre-trained model to load from: create a new model instance.
            self.model = self._instantiate_model_class(model_data)
        # NOTE(review): in finetune mode ``self.model`` is presumably loaded
        # elsewhere before ``train`` runs — confirm against the loader.

        settings = self.component_config
        self.model.fit(
            model_data,
            settings[EPOCHS],
            settings[BATCH_SIZES],
            settings[EVAL_NUM_EXAMPLES],
            settings[EVAL_NUM_EPOCHS],
            settings[BATCH_STRATEGY],
        )