Rasa Entity Extraction: CRFEntityExtractor

An annotated walkthrough of Rasa's CRF-based entity extractor: the CRFToken container, the shared EntityExtractor base class, the CRFEntityExtractor component itself, and the BILOU tagging utilities.

class CRFToken:
    """Container for the per-token information the CRF consumes."""
    def __init__(
        self,
        text: Text,
        pos_tag: Text,
        pattern: Dict[Text, Any],
        dense_features: np.ndarray,
        entity_tag: Text,
        entity_role_tag: Text,
        entity_group_tag: Text,
    ):
        self.text = text
        self.pos_tag = pos_tag
        self.pattern = pattern
        self.dense_features = dense_features
        self.entity_tag = entity_tag
        self.entity_role_tag = entity_role_tag
        self.entity_group_tag = entity_group_tag

class EntityExtractor(Component):
    """Base class for components that extract entities."""
    def add_extractor_name(
        self, entities: List[Dict[Text, Any]]
    ) -> List[Dict[Text, Any]]:
        """Adds this extractor's name to a list of entities.

        Args:
            entities: the extracted entities.

        Returns:
            the modified entities.
        """
        for entity in entities:
            entity[EXTRACTOR] = self.name
        return entities

    def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
        """Adds this extractor's name to the list of processors for this entity.

        Args:
            entity: the extracted entity and its metadata.

        Returns:
            the modified entity.
        """
        if "processors" in entity:
            entity["processors"].append(self.name)
        else:
            entity["processors"] = [self.name]

        return entity

    def init_split_entities(self) -> Dict[Text, bool]:
        split_entities_config = self.component_config.get(
            SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
        )
        default_value = self.defaults.get(
            SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
        )
        return rasa.utils.train_utils.init_split_entities(
            split_entities_config, default_value
        )

    @staticmethod
    def filter_irrelevant_entities(extracted: list, requested_dimensions: set) -> list:
        """Only return dimensions the user configured."""
        if requested_dimensions:
            return [
                entity
                for entity in extracted
                if entity[ENTITY_ATTRIBUTE_TYPE] in requested_dimensions
            ]
        return extracted

    @staticmethod
    def find_entity(
        entity: Dict[Text, Any], text: Text, tokens: List[Token]
    ) -> Tuple[int, int]:
        offsets = [token.start for token in tokens]
        ends = [token.end for token in tokens]

        if entity[ENTITY_ATTRIBUTE_START] not in offsets:
            message = (
                "Invalid entity {} in example '{}': "
                "entities must span whole tokens. "
                "Wrong entity start.".format(entity, text)
            )
            raise ValueError(message)

        if entity[ENTITY_ATTRIBUTE_END] not in ends:
            message = (
                "Invalid entity {} in example '{}': "
                "entities must span whole tokens. "
                "Wrong entity end.".format(entity, text)
            )
            raise ValueError(message)

        start = offsets.index(entity[ENTITY_ATTRIBUTE_START])
        end = ends.index(entity[ENTITY_ATTRIBUTE_END]) + 1
        return start, end
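
To make the offset-to-index mapping concrete, here is a minimal sketch using the jieba tokenization of the running example "今天上海的天气怎么样" (the SimpleToken stand-in and the literal offsets are illustrative, not part of the Rasa source):

class SimpleToken:  # hypothetical stand-in for rasa.nlu.tokenizers.tokenizer.Token
    def __init__(self, start: int, end: int):
        self.start, self.end = start, end

# "今天上海的天气怎么样" -> 今天 | 上海 | 的 | 天气 | 怎么样
tokens = [SimpleToken(0, 2), SimpleToken(2, 4), SimpleToken(4, 5),
          SimpleToken(5, 7), SimpleToken(7, 10)]
entity = {"start": 2, "end": 4}  # annotation for "上海"

offsets = [t.start for t in tokens]     # [0, 2, 4, 5, 7]
ends = [t.end for t in tokens]          # [2, 4, 5, 7, 10]
start = offsets.index(entity["start"])  # 1
end = ends.index(entity["end"]) + 1     # 2 -> tokens[start:end] spans "上海"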

    def filter_trainable_entities(
        self, entity_examples: List[Message]
    ) -> List[Message]:
        """过滤掉不可训练的实体注释. """
        # 创建 entity_examples 的副本,其中删除了将 `extractor` 设置为 self.name 以外的内容(例如,'CRFEntityExtractor')的实体。
        filtered = []
        for message in entity_examples:
            entities = []
            for ent in message.get(ENTITIES, []):
                extractor = ent.get(EXTRACTOR)
                if not extractor or extractor == self.name:
                    entities.append(ent)
            data = message.data.copy()
            data[ENTITIES] = entities
            filtered.append(
                Message(
                    text=message.get(TEXT),
                    data=data,
                    output_properties=message.output_properties,
                    time=message.time,
                    features=message.features,
                )
            )

        return filtered

    @staticmethod
    def convert_predictions_into_entities(
        text: Text,
        tokens: List[Token],
        tags: Dict[Text, List[Text]],
        split_entities_config: Optional[Dict[Text, bool]] = None,
        confidences: Optional[Dict[Text, List[float]]] = None,
    ) -> List[Dict[Text, Any]]:
        """将预测转换为实体。"""
        import rasa.nlu.utils.bilou_utils as bilou_utils

        entities = []

        last_entity_tag = NO_ENTITY_TAG
        last_role_tag = NO_ENTITY_TAG
        last_group_tag = NO_ENTITY_TAG
        last_token_end = -1

        for idx, token in enumerate(tokens):
            current_entity_tag = EntityExtractor.get_tag_for(
                tags, ENTITY_ATTRIBUTE_TYPE, idx
            )

            if current_entity_tag == NO_ENTITY_TAG:
                last_entity_tag = NO_ENTITY_TAG
                last_token_end = token.end
                continue

            current_group_tag = EntityExtractor.get_tag_for(
                tags, ENTITY_ATTRIBUTE_GROUP, idx
            )
            current_group_tag = bilou_utils.tag_without_prefix(current_group_tag)
            current_role_tag = EntityExtractor.get_tag_for(
                tags, ENTITY_ATTRIBUTE_ROLE, idx
            )
            current_role_tag = bilou_utils.tag_without_prefix(current_role_tag)

            group_or_role_changed = (
                last_group_tag != current_group_tag or last_role_tag != current_role_tag
            )

            if bilou_utils.bilou_prefix_from_tag(current_entity_tag):
                # check whether a new BILOU tag starts:
                # a new BILOU tag does not start with an I- or L- prefix
                new_bilou_tag_starts = last_entity_tag != current_entity_tag and (
                    bilou_utils.LAST
                    != bilou_utils.bilou_prefix_from_tag(current_entity_tag)
                    and bilou_utils.INSIDE
                    != bilou_utils.bilou_prefix_from_tag(current_entity_tag)
                )

                # handle malformed BILOU tags, e.g. only I- and L- tags
                # without a B- tag, and handle multiple U- tags in a row
                new_unigram_bilou_tag_starts = (
                    last_entity_tag == NO_ENTITY_TAG
                    or bilou_utils.UNIT
                    == bilou_utils.bilou_prefix_from_tag(current_entity_tag)
                )

                new_tag_found = (
                    new_bilou_tag_starts
                    or new_unigram_bilou_tag_starts
                    or group_or_role_changed
                )
                last_entity_tag = current_entity_tag
                current_entity_tag = bilou_utils.tag_without_prefix(current_entity_tag)
            else:
                new_tag_found = (
                    last_entity_tag != current_entity_tag or group_or_role_changed
                )
                last_entity_tag = current_entity_tag

            if new_tag_found:
                # new entity found
                entity = EntityExtractor._create_new_entity(
                    list(tags.keys()),
                    current_entity_tag,
                    current_group_tag,
                    current_role_tag,
                    token,
                    idx,
                    confidences,
                )
                entities.append(entity)
            elif EntityExtractor._check_is_single_entity(
                text, token, last_token_end, split_entities_config, current_entity_tag
            ):
                # The current token has the same entity tag as the previous token
                # and the two tokens are separated by at most 3 symbols, where each
                # of those symbols must be punctuation (e.g. "." or ",") or a space.
                entities[-1][ENTITY_ATTRIBUTE_END] = token.end
                if confidences is not None:
                    EntityExtractor._update_confidence_values(
                        entities, confidences, idx
                    )

            else:
                # The token has the same entity tag as the previous one, but the
                # two tokens are separated by at least 2 symbols (e.g. multiple
                # spaces, a comma and a space, etc.) and should not be represented
                # as a single entity.
                entity = EntityExtractor._create_new_entity(
                    list(tags.keys()),
                    current_entity_tag,
                    current_group_tag,
                    current_role_tag,
                    token,
                    idx,
                    confidences,
                )
                entities.append(entity)

            last_group_tag = current_group_tag
            last_role_tag = current_role_tag
            last_token_end = token.end

        for entity in entities:
            entity[ENTITY_ATTRIBUTE_VALUE] = text[
                entity[ENTITY_ATTRIBUTE_START] : entity[ENTITY_ATTRIBUTE_END]
            ]

        return entities
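
For the running example "今天上海的天气怎么样" (tokens 今天 | 上海 | 的 | 天气 | 怎么样) with predicted tags {'entity': ['U-date', 'U-location', 'O', 'O', 'O']}, this method would produce roughly the following (a sketch of the expected shape; the confidence keys only appear when confidences are passed in):

[
    {"entity": "date", "start": 0, "end": 2, "value": "今天"},
    {"entity": "location", "start": 2, "end": 4, "value": "上海"},
]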

    @staticmethod
    def _update_confidence_values(
        entities: List[Dict[Text, Any]], confidences: Dict[Text, List[float]], idx: int
    ):
        # use the lower confidence value
        entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_TYPE] = min(
            entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_TYPE],
            confidences[ENTITY_ATTRIBUTE_TYPE][idx],
        )
        if ENTITY_ATTRIBUTE_ROLE in entities[-1]:
            entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_ROLE] = min(
                entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_ROLE],
                confidences[ENTITY_ATTRIBUTE_ROLE][idx],
            )
        if ENTITY_ATTRIBUTE_GROUP in entities[-1]:
            entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_GROUP] = min(
                entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_GROUP],
                confidences[ENTITY_ATTRIBUTE_GROUP][idx],
            )

    @staticmethod
    def _check_is_single_entity(
        text: Text,
        token: Token,
        last_token_end: int,
        split_entities_config: Dict[Text, bool],
        current_entity_tag: Text,
    ):
        # the current token has the same entity tag as the previous token and the
        # two tokens are separated by at most one symbol (e.g. space, dash, etc.)
        if token.start - last_token_end <= 1:
            return True

        # Tokens need to be no more than 3 positions apart.
        # The magic number 3 was chosen so that the following two cases can be
        # extracted:
        #   - Schönhauser Allee 175, 10119 Berlin
        #     (address compounds separated by 2 characters (", "))
        #   - 22 Powderhall Rd., EH7 4GB
        #     (the abbreviation "Rd." leads to a separation of 3 characters ("., "))
        # More than 3 might already introduce cases this logic should not handle.
        tokens_within_range = token.start - last_token_end <= 3

        # the interleaving characters *must* be full stops, commas, or whitespace
        interleaving_text = text[last_token_end : token.start]
        tokens_separated_by_allowed_chars = all(
            char in SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET
            for char in interleaving_text
        )

        # the current entity type must match the config (the default value is True)
        default_value = split_entities_config[SPLIT_ENTITIES_BY_COMMA]
        split_current_entity_type = split_entities_config.get(
            current_entity_tag, default_value
        )

        return (
            tokens_within_range
            and tokens_separated_by_allowed_chars
            and not split_current_entity_type
        )
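
A worked example of this check, with the allowed-character set inlined as an assumption (in the Rasa source it is the constant SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET):

# Rough standalone version of the merge condition; ALLOWED is an assumption.
ALLOWED = {".", ",", " ", ";"}

def is_single_entity(text: str, token_start: int, last_token_end: int,
                     split_by_comma: bool) -> bool:
    if token_start - last_token_end <= 1:  # adjacent or one separator
        return True
    within_range = token_start - last_token_end <= 3
    interleaving = text[last_token_end:token_start]
    only_allowed = all(ch in ALLOWED for ch in interleaving)
    return within_range and only_allowed and not split_by_comma

text = "Schönhauser Allee 175, 10119 Berlin"
# the token "175" ends at offset 21, "10119" starts at 23; text[21:23] == ", "
print(is_single_entity(text, 23, 21, split_by_comma=False))  # True  -> one entity
print(is_single_entity(text, 23, 21, split_by_comma=True))   # False -> two entities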

    @staticmethod
    def get_tag_for(tags: Dict[Text, List[Text]], tag_name: Text, idx: int) -> Text:
        """从标签列表中获取给定标签名称的值."""
        if tag_name in tags:
            return tags[tag_name][idx]
        return NO_ENTITY_TAG

    @staticmethod
    def _create_new_entity(
        tag_names: List[Text],
        entity_tag: Text,
        group_tag: Text,
        role_tag: Text,
        token: Token,
        idx: int,
        confidences: Optional[Dict[Text, List[float]]] = None,
    ) -> Dict[Text, Any]:
        """创建新的Entity"""
        entity = {
            ENTITY_ATTRIBUTE_TYPE: entity_tag,
            ENTITY_ATTRIBUTE_START: token.start,
            ENTITY_ATTRIBUTE_END: token.end,
        }

        if confidences is not None:
            entity[ENTITY_ATTRIBUTE_CONFIDENCE_TYPE] = confidences[
                ENTITY_ATTRIBUTE_TYPE
            ][idx]

        if ENTITY_ATTRIBUTE_ROLE in tag_names and role_tag != NO_ENTITY_TAG:
            entity[ENTITY_ATTRIBUTE_ROLE] = role_tag
            if confidences is not None:
                entity[ENTITY_ATTRIBUTE_CONFIDENCE_ROLE] = confidences[
                    ENTITY_ATTRIBUTE_ROLE
                ][idx]
        if ENTITY_ATTRIBUTE_GROUP in tag_names and group_tag != NO_ENTITY_TAG:
            entity[ENTITY_ATTRIBUTE_GROUP] = group_tag
            if confidences is not None:
                entity[ENTITY_ATTRIBUTE_CONFIDENCE_GROUP] = confidences[
                    ENTITY_ATTRIBUTE_GROUP
                ][idx]

        return entity

    @staticmethod
    def check_correct_entity_annotations(training_data: TrainingData) -> None:
        """检查实体是否在训练数据中正确注释。"""
        for example in training_data.entity_examples: # entity_examples
            entity_boundaries = [
                (entity[ENTITY_ATTRIBUTE_START], entity[ENTITY_ATTRIBUTE_END])
                for entity in example.get(ENTITIES)
            ]
            # token boundaries recorded by the tokenizer
            token_start_positions = [t.start for t in example.get(TOKENS_NAMES[TEXT])]
            token_end_positions = [t.end for t in example.get(TOKENS_NAMES[TEXT])]

            for entity_start, entity_end in entity_boundaries:
                if (
                    # check whether the annotated entity boundaries line up with
                    # the tokenizer's (e.g. jieba's) token boundaries
                    entity_start not in token_start_positions
                    or entity_end not in token_end_positions
                ):
                    entities_repr = [
                        (
                            entity[ENTITY_ATTRIBUTE_START],
                            entity[ENTITY_ATTRIBUTE_END],
                            entity[ENTITY_ATTRIBUTE_VALUE],
                        )
                        for entity in example.get(ENTITIES)
                    ]
                    tokens_repr = [
                        (t.start, t.end, t.text)
                        for t in example.get(TOKENS_NAMES[TEXT])
                    ]
                    rasa.shared.utils.io.raise_warning(
                        f"Misaligned entity annotation in message '{example.get(TEXT)}' "
                        f"with intent '{example.get(INTENT)}'. Make sure the start and "
                        f"end values of entities ({entities_repr}) in the training "
                        f"data match the token boundaries ({tokens_repr}). "
                        "Common causes: \n  1) entities include trailing whitespaces or punctuation"
                        "\n  2) the tokenizer gives an unexpected result, due to "
                        "languages such as Chinese that don't use whitespace for word separation",
                        docs=DOCS_URL_TRAINING_DATA_NLU,
                    )
                    break

rasa/nlu/extractors/crf_entity_extractor.py

class CRFEntityExtractor(EntityExtractor):
    """Conditional random field (CRF) entity extractor."""
    @classmethod
    def required_components(cls) -> List[Type[Component]]:
        return [Tokenizer]

    defaults = {
        # BILOU_flag determines whether to use BILOU tagging or not.
        # More rigorous, but requires more examples per entity.
        # Rule of thumb: use only if more than 100 examples per entity.
        BILOU_FLAG: True,
        # Split entities by comma: this makes sense e.g. for a list of
        # ingredients in a recipe, but it doesn't make sense for the parts
        # of an address.
        SPLIT_ENTITIES_BY_COMMA: True,
        # crf_features is a [before, token, after] array, where before, token
        # and after hold keys describing which features to use for each token.
        # For example, "title" in the `before` array produces the feature
        # "is the preceding token in title case?".
        # POS features require the SpacyTokenizer.
        # The pattern feature requires the RegexFeaturizer.
        "features": [
            ["low", "title", "upper"],
            [
                "low",
                "bias",
                "prefix5",
                "prefix2",
                "suffix5",
                "suffix3",
                "suffix2",
                "upper",
                "title",
                "digit",
                "pattern",
            ],
            ["low", "title", "upper"],
        ],
        # The maximum number of iterations for the optimization algorithm.
        "max_iterations": 50,
        # weight of the L1 regularization
        "L1_c": 0.1,
        # weight of the L2 regularization
        "L2_c": 0.1,
        # Names of the dense featurizers to use.
        # If the list is empty, all available dense features are used.
        "featurizers": [],
    }
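
For reference, these defaults can be overridden when the component is constructed. The keys below assume BILOU_FLAG == "BILOU_flag" and SPLIT_ENTITIES_BY_COMMA == "split_entities_by_comma"; in a real project the same values would go into the pipeline section of config.yml:

# Illustrative config override; the values are examples, not recommendations.
extractor = CRFEntityExtractor(
    component_config={
        "BILOU_flag": True,
        "split_entities_by_comma": True,
        "max_iterations": 100,  # give L-BFGS more iterations
        "L1_c": 0.1,
        "L2_c": 0.1,
        # must be an odd number of feature lists (see _validate_configuration)
        "features": [
            ["low", "title", "upper"],
            ["low", "bias", "suffix3", "suffix2", "digit", "pattern"],
            ["low", "title", "upper"],
        ],
    }
)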

    function_dict: Dict[Text, Callable[[CRFToken], Any]] = {
        "low": lambda crf_token: crf_token.text.lower(),
        "title": lambda crf_token: crf_token.text.istitle(),
        "prefix5": lambda crf_token: crf_token.text[:5],
        "prefix2": lambda crf_token: crf_token.text[:2],
        "suffix5": lambda crf_token: crf_token.text[-5:],
        "suffix3": lambda crf_token: crf_token.text[-3:],
        "suffix2": lambda crf_token: crf_token.text[-2:],
        "suffix1": lambda crf_token: crf_token.text[-1:],
        "bias": lambda crf_token: "bias",
        "pos": lambda crf_token: crf_token.pos_tag,
        "pos2": lambda crf_token: crf_token.pos_tag[:2]
        if crf_token.pos_tag is not None
        else None,
        "upper": lambda crf_token: crf_token.text.isupper(),
        "digit": lambda crf_token: crf_token.text.isdigit(),
        "pattern": lambda crf_token: crf_token.pattern,
        "text_dense_features": lambda crf_token: crf_token.dense_features,
        "entity": lambda crf_token: crf_token.entity_tag,
    }
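
A quick sketch of what these feature functions return for a hand-built CRFToken (values are illustrative; dense_features is normally a numpy array, an empty list is used here only for the demo):

token = CRFToken(
    text="Berlin", pos_tag="NNP", pattern={}, dense_features=[],
    entity_tag="U-location", entity_role_tag="O", entity_group_tag="O",
)
fd = CRFEntityExtractor.function_dict
print(fd["low"](token))      # 'berlin'
print(fd["title"](token))    # True
print(fd["suffix3"](token))  # 'lin'
print(fd["pos2"](token))     # 'NN'
print(fd["entity"](token))   # 'U-location'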

    def __init__(
        self,
        component_config: Optional[Dict[Text, Any]] = None,
        entity_taggers: Optional[Dict[Text, "CRF"]] = None,
    ) -> None:

        super().__init__(component_config)

        self.entity_taggers = entity_taggers

        self.crf_order = [
            ENTITY_ATTRIBUTE_TYPE,
            ENTITY_ATTRIBUTE_ROLE,
            ENTITY_ATTRIBUTE_GROUP,
        ]

        self._validate_configuration()

        self.split_entities_config = self.init_split_entities()

    def _validate_configuration(self) -> None:
        if len(self.component_config.get("features", [])) % 2 != 1:
            raise ValueError(
                "Need an odd number of crf feature lists to have a center word."
            )

    @classmethod
    def required_packages(cls) -> List[Text]:
        return ["sklearn_crfsuite", "sklearn"]

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        # checks whether there is at least one
        # example with an entity annotation
        if not training_data.entity_examples:
            logger.debug(
                "No training examples with entities present. Skip training"
                "of 'CRFEntityExtractor'."
            )
            return

        # validation: e.g. when jieba tokenization runs before the CRF, check
        # that the jieba token boundaries match the annotated entity boundaries
        self.check_correct_entity_annotations(training_data)

        if self.component_config[BILOU_FLAG]:
            # build the BILOU tags and store them on the Message objects
            bilou_utils.apply_bilou_schema(training_data)

        # only keep the CRFs for tags we actually have training data for
        self._update_crf_order(training_data)

        # filter out pre-trained entity examples
        entity_examples = self.filter_trainable_entities(training_data.nlu_examples)

        # a list of lists: [[CRFToken, ...], [CRFToken, ...]]
        dataset = [self._convert_to_crf_tokens(example) for example in entity_examples]

        self._train_model(dataset)

    def _update_crf_order(self, training_data: TrainingData):
        """Train only CRFs we actually have training data for. 仅训练我们实际拥有训练数据的 CRF"""
        _crf_order = []

        for tag_name in self.crf_order:  # ['entity', 'role', 'group']
            if tag_name == ENTITY_ATTRIBUTE_TYPE and training_data.entities:
                _crf_order.append(ENTITY_ATTRIBUTE_TYPE)
            elif tag_name == ENTITY_ATTRIBUTE_ROLE and training_data.entity_roles:
                _crf_order.append(ENTITY_ATTRIBUTE_ROLE)
            elif tag_name == ENTITY_ATTRIBUTE_GROUP and training_data.entity_groups:
                _crf_order.append(ENTITY_ATTRIBUTE_GROUP)

        self.crf_order = _crf_order

    def process(self, message: Message, **kwargs: Any) -> None:
        entities = self.extract_entities(message)
        entities = self.add_extractor_name(entities)
        message.set(ENTITIES, message.get(ENTITIES, []) + entities, add_to_output=True)

    def extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
        """Extract entities from the given message using the trained model(s)."""

        if self.entity_taggers is None:
            return []

        tokens = message.get(TOKENS_NAMES[TEXT])
        crf_tokens = self._convert_to_crf_tokens(message)

        predictions = {}
        for tag_name, entity_tagger in self.entity_taggers.items():
            # use predicted entity tags as features for second level CRFs
            include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
            if include_tag_features:
                self._add_tag_to_crf_token(crf_tokens, predictions)
            # [{'BOS': True, '0:low': '北京', '0:bias': 'bias', '0:prefix5': '北京', '0:prefix2': '北京', '0:suffix5': '北京', '0:suffix3': '北京', '0:suffix2': '北京', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '的', '1:title': False, '1:upper': False}]
            features = self._crf_tokens_to_features(crf_tokens, include_tag_features)
            predictions[tag_name] = entity_tagger.predict_marginals_single(features)
            """
             CRF预测结果 {'entity': [{'U-date': 0.0018662581715369866, 'U-location': 0.9534361235771649, 'O': 0.04310048830917551, 'B-location': 0.00040871033087749956, 'I-location': 0.0003764602861231511, 'L-location': 0.000811959325121927}, 
                                    {'U-date': 0.005369950688885148, 'U-location': 0.004552219783974233, 'O': 0.9834001333595512, 'B-location': 0.0020663145611778053, 'I-location': 0.002047266321012076, 'L-location': 0.002564115285399473}, 
                                    {'U-date': 0.005864643891460612, 'U-location': 0.005886398421449579, 'O': 0.9790798152952706, 'B-location': 0.0032587770383554573, 'I-location': 0.0023428997898938143, 'L-location': 0.00356746556357009}, 
                                    {'U-date': 0.013875775748399285, 'U-location': 0.0142575192728364, 'O': 0.9510489275011496, 'B-location': 0.007400299042951769, 'I-location': 0.004890205794937807, 'L-location': 0.008527272639725128},
                                    {'U-date': 0.00491783928007722, 'U-location': 0.004268236726299186, 'O': 0.9804322682816534, 'B-location': 0.0036810716731708877, 'I-location': 0.0036080863158401088, 'L-location': 0.0030924977229592533}]}
            """
        # convert predictions into a list of tags and a list of confidences
        tags, confidences = self._tag_confidences(tokens, predictions)

        return self.convert_predictions_into_entities(
            message.get(TEXT), tokens, tags, self.split_entities_config, confidences
        )

    def _add_tag_to_crf_token(
        self,
        crf_tokens: List[CRFToken],
        predictions: Dict[Text, List[Dict[Text, float]]],
    ):
        """Add predicted entity tags to CRF tokens."""
        if ENTITY_ATTRIBUTE_TYPE in predictions:
            _tags, _ = self._most_likely_tag(predictions[ENTITY_ATTRIBUTE_TYPE])
            for tag, token in zip(_tags, crf_tokens):
                token.entity_tag = tag

    def _most_likely_tag(
        self, predictions: List[Dict[Text, float]]
    ) -> Tuple[List[Text], List[float]]:
        """获取置信度最高的实体标签。 """
        _tags = []
        _confidences = []

        for token_predictions in predictions:
            tag = max(token_predictions, key=lambda key: token_predictions[key])
            _tags.append(tag)

            if self.component_config[BILOU_FLAG]:
                # if we are using BILOU tags, we sum up the probabilities of the
                # B-, I-, L- and U- tags of an entity
                _confidences.append(
                    sum(
                        _confidence
                        for _tag, _confidence in token_predictions.items()
                        if bilou_utils.tag_without_prefix(tag)
                        == bilou_utils.tag_without_prefix(_tag)
                    )
                )
            else:
                _confidences.append(token_predictions[tag])

        return _tags, _confidences
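
A worked example of the BILOU confidence summation, using the first token's marginals from the prediction dump shown in extract_entities above (values rounded; strip_prefix is a rough stand-in for bilou_utils.tag_without_prefix):

preds = {'U-date': 0.00187, 'U-location': 0.95344, 'O': 0.04310,
         'B-location': 0.00041, 'I-location': 0.00038, 'L-location': 0.00081}

def strip_prefix(t: str) -> str:  # stand-in for bilou_utils.tag_without_prefix
    return t.split("-", 1)[1] if "-" in t else t

tag = max(preds, key=lambda k: preds[k])  # 'U-location'
confidence = sum(p for t, p in preds.items()
                 if strip_prefix(t) == strip_prefix(tag))  # ~0.955, not just 0.953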

    def _tag_confidences(
        self, tokens: List[Token], predictions: Dict[Text, List[Dict[Text, float]]]
    ) -> Tuple[Dict[Text, List[Text]], Dict[Text, List[float]]]:
        """使用标记的置信度值获取最有可能的标签预测。"""
        tags = {}
        confidences = {}

        for tag_name, predicted_tags in predictions.items():
            if len(tokens) != len(predicted_tags):
                raise Exception(
                    "Inconsistency in amount of tokens between crfsuite and message"
                )

            _tags, _confidences = self._most_likely_tag(predicted_tags)

            if self.component_config[BILOU_FLAG]:
                _tags, _confidences = bilou_utils.ensure_consistent_bilou_tagging(
                    _tags, _confidences
                )

            confidences[tag_name] = _confidences
            tags[tag_name] = _tags

        return tags, confidences

    @classmethod
    def load(
        cls,
        meta: Dict[Text, Any],
        model_dir: Optional[Text] = None,
        model_metadata: Optional[Metadata] = None,
        cached_component: Optional["CRFEntityExtractor"] = None,
        **kwargs: Any,
    ) -> "CRFEntityExtractor":
        import joblib

        file_names = meta.get("files")
        entity_taggers = {}

        if not file_names:
            logger.debug(
                f"Failed to load model for 'CRFEntityExtractor'. "
                f"Maybe you did not provide enough training data and no model was "
                f"trained or the path '{os.path.abspath(model_dir)}' doesn't exist?"
            )
            return cls(component_config=meta)

        for name, file_name in file_names.items():
            model_file = os.path.join(model_dir, file_name)
            if os.path.exists(model_file):
                entity_taggers[name] = joblib.load(model_file)
            else:
                logger.debug(
                    f"Failed to load model for tag '{name}' for 'CRFEntityExtractor'. "
                    f"Maybe you did not provide enough training data and no model was "
                    f"trained or the path '{os.path.abspath(model_file)}' doesn't "
                    f"exist?"
                )

        return cls(meta, entity_taggers)

    def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
        """将此模型保存到传递的目录中。"""

        import joblib

        file_names = {}

        if self.entity_taggers:
            for name, entity_tagger in self.entity_taggers.items():
                # use a separate variable so the suffixes don't accumulate on
                # `file_name` across iterations
                tagger_file_name = f"{file_name}.{name}.pkl"
                model_file_name = os.path.join(model_dir, tagger_file_name)
                joblib.dump(entity_tagger, model_file_name)
                file_names[name] = tagger_file_name

        return {"files": file_names}

    def _crf_tokens_to_features(
        self, crf_tokens: List[CRFToken], include_tag_features: bool = False
    ) -> List[Dict[Text, Any]]:
        """将标记列表转换为离散特征"""
        configured_features = self.component_config["features"]
        sentence_features = []

        for token_idx in range(len(crf_tokens)):
            # the features for the current token include features of the tokens
            # before and after the current token (if defined in the config):
            # token before (-1), current token (0), token after (+1)
            window_size = len(configured_features)
            half_window_size = window_size // 2
            window_range = range(-half_window_size, half_window_size + 1)

            token_features = self._create_features_for_token(
                crf_tokens,
                token_idx,
                half_window_size,
                window_range,
                include_tag_features,
            )  # the features for one token, e.g. {'BOS': True, '0:low': '今天', '0:bias': 'bias', '0:prefix5': '今天', '0:prefix2': '今天', '0:suffix5': '今天', '0:suffix3': '今天', '0:suffix2': '今天', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '上海', '1:title': False, '1:upper': False}

            sentence_features.append(token_features)

        return sentence_features

    def _create_features_for_token(
        self,
        crf_tokens: List[CRFToken],
        token_idx: int,
        half_window_size: int,
        window_range: range,
        include_tag_features: bool,
    ):
        """将标记转换为离散特征,包括词前和词后。"""

        configured_features = self.component_config["features"]  # e.g. [['low', 'title', 'upper'], ['low', 'bias', 'prefix5', 'prefix2', 'suffix5', 'suffix3', 'suffix2', 'upper', 'title', 'digit', 'pattern'], ['low', 'title', 'upper']]
        prefixes = [str(i) for i in window_range]  # e.g. ['-1', '0', '1']

        token_features = {}

        # iterate over the tokens in the window range (-1, 0, +1) to collect the
        # features for the token at token_idx
        for pointer_position in window_range:
            current_token_idx = token_idx + pointer_position

            if current_token_idx >= len(crf_tokens):
                # token is at the end of the sentence
                token_features["EOS"] = True
            elif current_token_idx < 0:
                # token is at the beginning of the sentence
                token_features["BOS"] = True
            else:
                token = crf_tokens[current_token_idx]

                # get the features to extract for the token we are currently looking at
                current_feature_idx = pointer_position + half_window_size
                features = configured_features[current_feature_idx]

                prefix = prefixes[current_feature_idx]
                # We add the 'entity' feature to include the entity type as a
                # feature for the role and group CRFs (don't modify the features
                # list itself, otherwise we would append 'entity' over and over
                # again, making training very slow).
                additional_features = []
                if include_tag_features:
                    additional_features.append("entity")

                for feature in features + additional_features:
                    if feature == "pattern":
                        # Add all regexes extracted by the 'RegexFeaturizer' as
                        # features: 'pattern_name' is the name of the pattern the
                        # user set in the training data, 'matched' is True or
                        # False depending on whether the token actually matched
                        # the pattern or not.
                        regex_patterns = self.function_dict[feature](token)
                        for pattern_name, matched in regex_patterns.items():
                            token_features[
                                f"{prefix}:{feature}:{pattern_name}"
                            ] = matched
                    else:
                        value = self.function_dict[feature](token)
                        token_features[f"{prefix}:{feature}"] = value

        return token_features

    @staticmethod
    def _crf_tokens_to_tags(crf_tokens: List[CRFToken], tag_name: Text) -> List[Text]:
        """Return the list of tags for the given tag name.  返回给定标签名称的标签列表"""
        if tag_name == ENTITY_ATTRIBUTE_ROLE:
            return [crf_token.entity_role_tag for crf_token in crf_tokens]
        if tag_name == ENTITY_ATTRIBUTE_GROUP:
            return [crf_token.entity_group_tag for crf_token in crf_tokens]

        return [crf_token.entity_tag for crf_token in crf_tokens]

    @staticmethod
    def _pattern_of_token(message: Message, idx: int) -> Dict[Text, bool]:
        """获取 'RegexFeaturizer 提取的给定索引处的令牌模式
        “RegexFeaturizer”将训练数据中列出的所有pattern 添加到令牌中。 pattern名称映射到“true”(pattern 适用于令牌)或“false”(pattern 不适用于令牌)。
        """
        if message.get(TOKENS_NAMES[TEXT]) is not None:
            return message.get(TOKENS_NAMES[TEXT])[idx].get("pattern", {})
        return {}

    def _get_dense_features(self, message: Message) -> Optional[List]:
        """将密集特征转换为 python-crfsuite 特征格式。"""
        features, _ = message.get_dense_features(
            TEXT, self.component_config["featurizers"]
        )  # the value of self.features stored on the message

        if features is None:
            return None

        tokens = message.get(TOKENS_NAMES[TEXT])
        if len(tokens) != len(features.features):
            rasa.shared.utils.io.raise_warning(
                f"Number of dense features ({len(features.features)}) for attribute "
                f"'TEXT' does not match number of tokens ({len(tokens)}).",
                docs=DOCS_URL_COMPONENTS + "#crfentityextractor",
            )
            return None

        # convert to python-crfsuite feature format
        features_out = []
        for feature in features.features:
            feature_dict = {
                str(index): token_features
                for index, token_features in enumerate(feature)
            }
            converted = {"text_dense_features": feature_dict}
            features_out.append(converted)

        return features_out

    def _convert_to_crf_tokens(self, message: Message) -> List[CRFToken]:
        """获取消息并将其转换为 crfsuite 格式。"""
        # 将消息转换成crfsuite格式
        crf_format = []
        tokens = message.get(TOKENS_NAMES[TEXT])

        text_dense_features = self._get_dense_features(message)  # dense features in python-crfsuite format
        tags = self._get_tags(message)  # e.g. {'entity': ['U-date', 'U-location', 'O', 'O', 'O']}

        for i, token in enumerate(tokens):
            pattern = self._pattern_of_token(message, i)
            entity = self.get_tag_for(tags, ENTITY_ATTRIBUTE_TYPE, i)
            group = self.get_tag_for(tags, ENTITY_ATTRIBUTE_GROUP, i)
            role = self.get_tag_for(tags, ENTITY_ATTRIBUTE_ROLE, i)
            pos_tag = token.get(POS_TAG_KEY)
            dense_features = (
                text_dense_features[i] if text_dense_features is not None else []
            )

            crf_format.append(
                CRFToken(
                    text=token.text,
                    pos_tag=pos_tag,
                    entity_tag=entity,
                    entity_group_tag=group,
                    entity_role_tag=role,
                    pattern=pattern,
                    dense_features=dense_features,
                )
            )

        return crf_format  # [CRFToken, ...]

    def _get_tags(self, message: Message) -> Dict[Text, List[Text]]:
        """获取分配的消息实体标签。"""
        tokens = message.get(TOKENS_NAMES[TEXT])
        tags = {}

        for tag_name in self.crf_order:
            if self.component_config[BILOU_FLAG]:
                bilou_key = bilou_utils.get_bilou_key_for_tag(tag_name)
                if message.get(bilou_key):
                    _tags = message.get(bilou_key)
                else:
                    _tags = [NO_ENTITY_TAG for _ in tokens]
            else:
                _tags = [
                    determine_token_labels(
                        token, message.get(ENTITIES), attribute_key=tag_name
                    )
                    for token in tokens
                ]
            tags[tag_name] = _tags

        return tags

    def _train_model(self, df_train: List[List[CRFToken]]) -> None:
        """Train the crf tagger based on the training data."""
        import sklearn_crfsuite

        self.entity_taggers = {}

        for tag_name in self.crf_order:
            logger.debug(f"Training CRF for '{tag_name}'.")

            # add entity tag features for second level CRFs
            include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
            X_train = [
                self._crf_tokens_to_features(sentence, include_tag_features)
                for sentence in df_train
            ]
            # Example CRF training features for "今天上海的天气怎么样":
            # [[{'BOS': True, '0:low': '今天', '0:bias': 'bias', '0:prefix5': '今天',
            #    '0:prefix2': '今天', '0:suffix5': '今天', '0:suffix3': '今天',
            #    '0:suffix2': '今天', '0:upper': False, '0:title': False,
            #    '0:digit': False, '1:low': '上海', '1:title': False, '1:upper': False},
            #   {'-1:low': '今天', '-1:title': False, '-1:upper': False, '0:low': '上海',
            #    '0:bias': 'bias', '0:prefix5': '上海', '0:prefix2': '上海',
            #    '0:suffix5': '上海', '0:suffix3': '上海', '0:suffix2': '上海',
            #    '0:upper': False, '0:title': False, '0:digit': False, '1:low': '的',
            #    '1:title': False, '1:upper': False},
            #   {'-1:low': '上海', '-1:title': False, '-1:upper': False, '0:low': '的',
            #    '0:bias': 'bias', '0:prefix5': '的', '0:prefix2': '的',
            #    '0:suffix5': '的', '0:suffix3': '的', '0:suffix2': '的',
            #    '0:upper': False, '0:title': False, '0:digit': False, '1:low': '天气',
            #    '1:title': False, '1:upper': False},
            #   {'-1:low': '的', '-1:title': False, '-1:upper': False, '0:low': '天气',
            #    '0:bias': 'bias', '0:prefix5': '天气', '0:prefix2': '天气',
            #    '0:suffix5': '天气', '0:suffix3': '天气', '0:suffix2': '天气',
            #    '0:upper': False, '0:title': False, '0:digit': False,
            #    '1:low': '怎么样', '1:title': False, '1:upper': False},
            #   {'-1:low': '天气', '-1:title': False, '-1:upper': False,
            #    '0:low': '怎么样', '0:bias': 'bias', '0:prefix5': '怎么样',
            #    '0:prefix2': '怎么', '0:suffix5': '怎么样', '0:suffix3': '怎么样',
            #    '0:suffix2': '么样', '0:upper': False, '0:title': False,
            #    '0:digit': False, 'EOS': True}]]
            y_train = [
                self._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
            ]  # e.g. [['U-date', 'U-location', 'O', 'O', 'O']]

            entity_tagger = sklearn_crfsuite.CRF(
                algorithm="lbfgs",
                # coefficient for L1 penalty
                c1=self.component_config["L1_c"],
                # coefficient for L2 penalty
                c2=self.component_config["L2_c"],
                # stop earlier (caps the number of L-BFGS iterations)
                max_iterations=self.component_config["max_iterations"],
                # include transitions that are possible, but not observed
                all_possible_transitions=True,
            )
            entity_tagger.fit(X_train, y_train)

            self.entity_taggers[tag_name] = entity_tagger

            logger.debug("Training finished.")

rasa/nlu/utils/bilou_utils.py
Tags tokens with BILOU annotations.

def apply_bilou_schema_to_message(message: "Message") -> None:
    """获取 BILOU 实体标签列表并将它们设置在给定的消息上。"""
    entities = message.get(ENTITIES)

    if not entities:
        return

    tokens = message.get(TOKENS_NAMES[TEXT])

    for attribute, message_key in [
        (ENTITY_ATTRIBUTE_TYPE, BILOU_ENTITIES),        # entity -> bilou_entities
        (ENTITY_ATTRIBUTE_ROLE, BILOU_ENTITIES_ROLE),   # role -> bilou_entities_role
        (ENTITY_ATTRIBUTE_GROUP, BILOU_ENTITIES_GROUP), # group -> bilou_entities_group
    ]:
        entities = map_message_entities(message, attribute)  # e.g. [(0, 2, 'date'), (2, 4, 'location')]
        output = bilou_tags_from_offsets(tokens, entities)  # e.g. entity: ['U-date', 'U-location', 'O', 'O', 'O']; role and group: all 'O'
        message.set(message_key, output)
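
The offsets-to-BILOU conversion can be sketched as follows (a simplified reimplementation for illustration only; the real bilou_tags_from_offsets also deals with entities that do not align with token boundaries):

from typing import List, Tuple

def bilou_tags_from_offsets_sketch(
    spans: List[Tuple[int, int]], entities: List[Tuple[int, int, str]]
) -> List[str]:
    tags = ["O"] * len(spans)
    for ent_start, ent_end, label in entities:
        idxs = [i for i, (s, e) in enumerate(spans)
                if s >= ent_start and e <= ent_end]
        if len(idxs) == 1:
            tags[idxs[0]] = f"U-{label}"
        elif idxs:
            tags[idxs[0]] = f"B-{label}"
            for i in idxs[1:-1]:
                tags[i] = f"I-{label}"
            tags[idxs[-1]] = f"L-{label}"
    return tags

# 今天 | 上海 | 的 | 天气 | 怎么样
spans = [(0, 2), (2, 4), (4, 5), (5, 7), (7, 10)]
print(bilou_tags_from_offsets_sketch(spans, [(0, 2, "date"), (2, 4, "location")]))
# ['U-date', 'U-location', 'O', 'O', 'O']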