Rasa Entity Extraction: CRFEntityExtractor
class CRFToken:
def __init__(
self,
text: Text,
pos_tag: Text,
pattern: Dict[Text, Any],
dense_features: np.ndarray,
entity_tag: Text,
entity_role_tag: Text,
entity_group_tag: Text,
):
self.text = text
self.pos_tag = pos_tag
self.pattern = pattern
self.dense_features = dense_features
self.entity_tag = entity_tag
self.entity_role_tag = entity_role_tag
self.entity_group_tag = entity_group_tag
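# A minimal sketch of what one CRFToken carries for the token "Berlin"
# (illustrative values; pos_tag comes from SpacyTokenizer, pattern from
# RegexFeaturizer, dense_features from an optional dense featurizer):
import numpy as np

berlin = CRFToken(
    text="Berlin",
    pos_tag="NNP",
    pattern={"city_regex": True},
    dense_features=np.array([]),
    entity_tag="U-location",
    entity_role_tag="O",
    entity_group_tag="O",
)
rasa/nlu/extractors/extractor.py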
class EntityExtractor(Component):
def add_extractor_name(
self, entities: List[Dict[Text, Any]]
) -> List[Dict[Text, Any]]:
"""Adds this extractor's name to a list of entities.
Args:
entities: the extracted entities.
Returns:
the modified entities.
"""
for entity in entities:
entity[EXTRACTOR] = self.name
return entities
def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
"""Adds this extractor's name to the list of processors for this entity.
Args:
entity: the extracted entity and its metadata.
Returns:
the modified entity.
"""
if "processors" in entity:
entity["processors"].append(self.name)
else:
entity["processors"] = [self.name]
return entity
def init_split_entities(self) -> Dict[Text, bool]:
split_entities_config = self.component_config.get(
SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
)
default_value = self.defaults.get(
SPLIT_ENTITIES_BY_COMMA, SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE
)
return rasa.utils.train_utils.init_split_entities(
split_entities_config, default_value
)
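# A simplified sketch of what rasa.utils.train_utils.init_split_entities returns
# (behavior inferred from its use here; the key literal "split_entities_by_comma"
# mirrors the SPLIT_ENTITIES_BY_COMMA constant):
def init_split_entities_sketch(split_entities_config, default_value):
    if isinstance(split_entities_config, bool):
        # a plain bool applies to all entity types
        return {"split_entities_by_comma": split_entities_config}
    # a per-entity-type dict is kept, plus the global default
    split_entities_config["split_entities_by_comma"] = default_value
    return split_entities_config

print(init_split_entities_sketch(True, True))
# {'split_entities_by_comma': True}
print(init_split_entities_sketch({"address": False}, True))
# {'address': False, 'split_entities_by_comma': True}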
@staticmethod
def filter_irrelevant_entities(extracted: list, requested_dimensions: set) -> list:
"""Only return dimensions the user configured."""
if requested_dimensions:
return [
entity
for entity in extracted
if entity[ENTITY_ATTRIBUTE_TYPE] in requested_dimensions
]
return extracted
@staticmethod
def find_entity(
entity: Dict[Text, Any], text: Text, tokens: List[Token]
) -> Tuple[int, int]:
offsets = [token.start for token in tokens]
ends = [token.end for token in tokens]
if entity[ENTITY_ATTRIBUTE_START] not in offsets:
message = (
"Invalid entity {} in example '{}': "
"entities must span whole tokens. "
"Wrong entity start.".format(entity, text)
)
raise ValueError(message)
if entity[ENTITY_ATTRIBUTE_END] not in ends:
message = (
"Invalid entity {} in example '{}': "
"entities must span whole tokens. "
"Wrong entity end.".format(entity, text)
)
raise ValueError(message)
start = offsets.index(entity[ENTITY_ATTRIBUTE_START])
end = ends.index(entity[ENTITY_ATTRIBUTE_END]) + 1
return start, end
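# A worked example of the index arithmetic above, with hypothetical tokens
# for the text "fly to Berlin":
offsets = [0, 4, 7]                     # token starts: "fly", "to", "Berlin"
ends = [3, 6, 13]                       # token ends
entity = {"start": 7, "end": 13}        # annotation covering "Berlin"
start = offsets.index(entity["start"])  # 2
end = ends.index(entity["end"]) + 1     # 3
# tokens[start:end] is exactly the annotated span; a start of 8 or an end of
# 12 would not hit a token boundary and would raise the ValueError above.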
def filter_trainable_entities(
self, entity_examples: List[Message]
) -> List[Message]:
"""过滤掉不可训练的实体注释. """
# 创建 entity_examples 的副本,其中删除了将 `extractor` 设置为 self.name 以外的内容(例如,'CRFEntityExtractor')的实体。
filtered = []
for message in entity_examples:
entities = []
for ent in message.get(ENTITIES, []):
extractor = ent.get(EXTRACTOR)
if not extractor or extractor == self.name:
entities.append(ent)
data = message.data.copy()
data[ENTITIES] = entities
filtered.append(
Message(
text=message.get(TEXT),
data=data,
output_properties=message.output_properties,
time=message.time,
features=message.features,
)
)
return filtered
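# The filtering condition in isolation (entity dicts are illustrative):
extractor_name = "CRFEntityExtractor"
entities = [
    {"entity": "time", "extractor": "DucklingEntityExtractor"},  # pre-trained, dropped
    {"entity": "city", "extractor": "CRFEntityExtractor"},       # our own, kept
    {"entity": "dish"},                                          # plain annotation, kept
]
trainable = [
    ent for ent in entities
    if not ent.get("extractor") or ent.get("extractor") == extractor_name
]
print([ent["entity"] for ent in trainable])  # ['city', 'dish']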
@staticmethod
def convert_predictions_into_entities(
text: Text,
tokens: List[Token],
tags: Dict[Text, List[Text]],
split_entities_config: Optional[Dict[Text, bool]] = None,
confidences: Optional[Dict[Text, List[float]]] = None,
) -> List[Dict[Text, Any]]:
"""将预测转换为实体。"""
import rasa.nlu.utils.bilou_utils as bilou_utils
entities = []
last_entity_tag = NO_ENTITY_TAG
last_role_tag = NO_ENTITY_TAG
last_group_tag = NO_ENTITY_TAG
last_token_end = -1
for idx, token in enumerate(tokens):
current_entity_tag = EntityExtractor.get_tag_for(
tags, ENTITY_ATTRIBUTE_TYPE, idx
)
if current_entity_tag == NO_ENTITY_TAG:
last_entity_tag = NO_ENTITY_TAG
last_token_end = token.end
continue
current_group_tag = EntityExtractor.get_tag_for(
tags, ENTITY_ATTRIBUTE_GROUP, idx
)
current_group_tag = bilou_utils.tag_without_prefix(current_group_tag)
current_role_tag = EntityExtractor.get_tag_for(
tags, ENTITY_ATTRIBUTE_ROLE, idx
)
current_role_tag = bilou_utils.tag_without_prefix(current_role_tag)
group_or_role_changed = (
last_group_tag != current_group_tag or last_role_tag != current_role_tag
)
if bilou_utils.bilou_prefix_from_tag(current_entity_tag):
# check whether a new BILOU tag starts:
# a new BILOU tag does not start with an I- or L- prefix
new_bilou_tag_starts = last_entity_tag != current_entity_tag and (
bilou_utils.LAST
!= bilou_utils.bilou_prefix_from_tag(current_entity_tag)
and bilou_utils.INSIDE
!= bilou_utils.bilou_prefix_from_tag(current_entity_tag)
)
# handle inconsistent BILOU tagging, e.g. an entity that only has I- and L-
# tags but no B- tag, and handle multiple consecutive U- tags
new_unigram_bilou_tag_starts = (
last_entity_tag == NO_ENTITY_TAG
or bilou_utils.UNIT
== bilou_utils.bilou_prefix_from_tag(current_entity_tag)
)
new_tag_found = (
new_bilou_tag_starts
or new_unigram_bilou_tag_starts
or group_or_role_changed
)
last_entity_tag = current_entity_tag
current_entity_tag = bilou_utils.tag_without_prefix(current_entity_tag)
else:
new_tag_found = (
last_entity_tag != current_entity_tag or group_or_role_changed
)
last_entity_tag = current_entity_tag
if new_tag_found:
# new entity found
entity = EntityExtractor._create_new_entity(
list(tags.keys()),
current_entity_tag,
current_group_tag,
current_role_tag,
token,
idx,
confidences,
)
entities.append(entity)
elif EntityExtractor._check_is_single_entity(
text, token, last_token_end, split_entities_config, current_entity_tag
):
# the current token has the same entity tag as the previous token, and the two tokens are separated by at most 3 characters, each of which must be punctuation (e.g. '.' or ',') or a space
entities[-1][ENTITY_ATTRIBUTE_END] = token.end
if confidences is not None:
EntityExtractor._update_confidence_values(
entities, confidences, idx
)
else:
# the token has the same entity tag as the token before, but the two tokens are separated by at least 2 symbols (e.g. multiple spaces, a comma and a space, etc.) and should therefore not be represented as a single entity
entity = EntityExtractor._create_new_entity(
list(tags.keys()),
current_entity_tag,
current_group_tag,
current_role_tag,
token,
idx,
confidences,
)
entities.append(entity)
last_group_tag = current_group_tag
last_role_tag = current_role_tag
last_token_end = token.end
for entity in entities:
entity[ENTITY_ATTRIBUTE_VALUE] = text[
entity[ENTITY_ATTRIBUTE_START] : entity[ENTITY_ATTRIBUTE_END]
]
return entities
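# To make the control flow above concrete, here is a self-contained toy version
# of the BILOU merging rule (a simplified sketch, not the Rasa implementation:
# it ignores roles, groups, confidences and the single-entity check, and uses
# token indices instead of character offsets):
tags = ["B-location", "I-location", "L-location", "O", "U-date"]
entities = []
last_label = None
for idx, tag in enumerate(tags):
    if tag == "O":
        last_label = None
        continue
    prefix, label = tag.split("-", 1)
    # B- and U- always start a new entity; I-/L- continue the previous one
    # unless the label changed (the "inconsistent tagging" case handled above)
    if prefix in ("B", "U") or label != last_label:
        entities.append({"entity": label, "start": idx, "end": idx + 1})
    else:
        entities[-1]["end"] = idx + 1
    last_label = None if prefix in ("L", "U") else label
print(entities)
# [{'entity': 'location', 'start': 0, 'end': 3}, {'entity': 'date', 'start': 4, 'end': 5}]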
@staticmethod
def _update_confidence_values(
entities: List[Dict[Text, Any]], confidences: Dict[Text, List[float]], idx: int
):
# use the lower confidence value
entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_TYPE] = min(
entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_TYPE],
confidences[ENTITY_ATTRIBUTE_TYPE][idx],
)
if ENTITY_ATTRIBUTE_ROLE in entities[-1]:
entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_ROLE] = min(
entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_ROLE],
confidences[ENTITY_ATTRIBUTE_ROLE][idx],
)
if ENTITY_ATTRIBUTE_GROUP in entities[-1]:
entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_GROUP] = min(
entities[-1][ENTITY_ATTRIBUTE_CONFIDENCE_GROUP],
confidences[ENTITY_ATTRIBUTE_GROUP][idx],
)
@staticmethod
def _check_is_single_entity(
text: Text,
token: Token,
last_token_end: int,
split_entities_config: Dict[Text, bool],
current_entity_tag: Text,
):
# the current token has the same entity tag as the previous token, and the two tokens are separated by at most one symbol (e.g. space, dash, etc.)
if token.start - last_token_end <= 1:
return True
# the tokens are not more than 3 positions apart
# the magic number 3 is chosen such that the following two cases can be extracted:
#   - Schönhauser Allee 175, 10119 Berlin (address compounds separated by 2 characters (", "))
#   - 22 Powderhall Rd., EH7 4GB (the abbreviation "Rd." leads to a separation of 3 characters ("., "))
# more than 3 might already introduce cases that this logic should not cover
tokens_within_range = token.start - last_token_end <= 3
# the interleaving characters *must* be a full stop, a comma, or a whitespace
interleaving_text = text[last_token_end : token.start]
tokens_separated_by_allowed_chars = all(
char in SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET
for char in interleaving_text
)
# the current entity type must match the config (the default value is True)
default_value = split_entities_config[SPLIT_ENTITIES_BY_COMMA]
split_current_entity_type = split_entities_config.get(
current_entity_tag, default_value
)
return (
tokens_within_range
and tokens_separated_by_allowed_chars
and not split_current_entity_type
)
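# A worked example of the distance and character checks (a simplified sketch;
# the real method additionally consults split_entities_config as shown above,
# and the exact contents of SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET are an
# assumption here):
ALLOWED = {",", ".", " ", ";"}

def looks_like_single_entity(text, last_token_end, token_start):
    if token_start - last_token_end <= 1:
        return True
    interleaving = text[last_token_end:token_start]
    return (token_start - last_token_end <= 3
            and all(char in ALLOWED for char in interleaving))

address = "Schönhauser Allee 175, 10119 Berlin"
# "175" ends at offset 21, "10119" starts at offset 23 -> interleaving text ", "
print(looks_like_single_entity(address, 21, 23))  # True: merged into one address
print(looks_like_single_entity(address, 21, 26))  # False: a gap of 5 is too wide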
@staticmethod
def get_tag_for(tags: Dict[Text, List[Text]], tag_name: Text, idx: int) -> Text:
"""从标签列表中获取给定标签名称的值."""
if tag_name in tags:
return tags[tag_name][idx]
return NO_ENTITY_TAG
@staticmethod
def _create_new_entity(
tag_names: List[Text],
entity_tag: Text,
group_tag: Text,
role_tag: Text,
token: Token,
idx: int,
confidences: Optional[Dict[Text, List[float]]] = None,
) -> Dict[Text, Any]:
"""创建新的Entity"""
entity = {
ENTITY_ATTRIBUTE_TYPE: entity_tag,
ENTITY_ATTRIBUTE_START: token.start,
ENTITY_ATTRIBUTE_END: token.end,
}
if confidences is not None:
entity[ENTITY_ATTRIBUTE_CONFIDENCE_TYPE] = confidences[
ENTITY_ATTRIBUTE_TYPE
][idx]
if ENTITY_ATTRIBUTE_ROLE in tag_names and role_tag != NO_ENTITY_TAG:
entity[ENTITY_ATTRIBUTE_ROLE] = role_tag
if confidences is not None:
entity[ENTITY_ATTRIBUTE_CONFIDENCE_ROLE] = confidences[
ENTITY_ATTRIBUTE_ROLE
][idx]
if ENTITY_ATTRIBUTE_GROUP in tag_names and group_tag != NO_ENTITY_TAG:
entity[ENTITY_ATTRIBUTE_GROUP] = group_tag
if confidences is not None:
entity[ENTITY_ATTRIBUTE_CONFIDENCE_GROUP] = confidences[
ENTITY_ATTRIBUTE_GROUP
][idx]
return entity
@staticmethod
def check_correct_entity_annotations(training_data: TrainingData) -> None:
"""检查实体是否在训练数据中正确注释。"""
for example in training_data.entity_examples: # entity_examples
entity_boundaries = [
(entity[ENTITY_ATTRIBUTE_START], entity[ENTITY_ATTRIBUTE_END])
for entity in example.get(ENTITIES)
]
token_start_positions = [t.start for t in example.get(TOKENS_NAMES[TEXT])]  # token boundaries produced by the tokenizer
token_end_positions = [t.end for t in example.get(TOKENS_NAMES[TEXT])]
for entity_start, entity_end in entity_boundaries:
if (
entity_start not in token_start_positions  # check that the annotated entity boundaries (entity_boundaries) match the tokenizer's token boundaries (e.g. jieba segmentation for Chinese)
or entity_end not in token_end_positions
):
entities_repr = [
(
entity[ENTITY_ATTRIBUTE_START],
entity[ENTITY_ATTRIBUTE_END],
entity[ENTITY_ATTRIBUTE_VALUE],
)
for entity in example.get(ENTITIES)
]
tokens_repr = [
(t.start, t.end, t.text)
for t in example.get(TOKENS_NAMES[TEXT])
]
rasa.shared.utils.io.raise_warning(
f"Misaligned entity annotation in message '{example.get(TEXT)}' "
f"with intent '{example.get(INTENT)}'. Make sure the start and "
f"end values of entities ({entities_repr}) in the training "
f"data match the token boundaries ({tokens_repr}). "
"Common causes: \n 1) entities include trailing whitespaces or punctuation"
"\n 2) the tokenizer gives an unexpected result, due to "
"languages such as Chinese that don't use whitespace for word separation",
docs=DOCS_URL_TRAINING_DATA_NLU,
)
break
rasa/nlu/extractors/crf_entity_extractor.py
class CRFEntityExtractor(EntityExtractor):
@classmethod
def required_components(cls) -> List[Type[Component]]:
return [Tokenizer]
defaults = {
# BILOU_flag determines whether to use BILOU tagging or not.
# More rigorous, but requires more examples per entity.
# Rule of thumb: only use it if you have more than 100 examples per entity.
BILOU_FLAG: True,
# Split entities by comma. This makes sense e.g. for a list of ingredients
# in a recipe, but it does not make sense for the parts of an address.
SPLIT_ENTITIES_BY_COMMA: True,
# crf_features is a [before, token, after] array whose entries hold the keys
# of the features to use for each token; e.g. "title" in the `before` array
# adds the feature "is the preceding token in title case?".
# POS features require the SpacyTokenizer.
# The pattern feature requires the RegexFeaturizer.
"features": [
["low", "title", "upper"],
[
"low",
"bias",
"prefix5",
"prefix2",
"suffix5",
"suffix3",
"suffix2",
"upper",
"title",
"digit",
"pattern",
],
["low", "title", "upper"],
],
# maximum number of iterations for the optimization algorithm
"max_iterations": 50,
# weight of the L1 regularization
"L1_c": 0.1,
# weight of the L2 regularization
"L2_c": 0.1,
# names of the dense features to use;
# if the list is empty, all available dense features are used
"featurizers": [],
}
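# A hypothetical override of these defaults (illustrative values only): a
# five-element "features" list widens the window to two tokens on each side;
# the list length must stay odd so there is a center word, see
# _validate_configuration below:
component_config = {
    "features": [
        ["low"],
        ["low", "title"],
        ["low", "bias", "digit", "suffix3"],
        ["low", "title"],
        ["low"],
    ],
    "max_iterations": 100,
    "L1_c": 0.3,
    "L2_c": 0.01,
}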
function_dict: Dict[Text, Callable[[CRFToken], Any]] = {
"low": lambda crf_token: crf_token.text.lower(),
"title": lambda crf_token: crf_token.text.istitle(),
"prefix5": lambda crf_token: crf_token.text[:5],
"prefix2": lambda crf_token: crf_token.text[:2],
"suffix5": lambda crf_token: crf_token.text[-5:],
"suffix3": lambda crf_token: crf_token.text[-3:],
"suffix2": lambda crf_token: crf_token.text[-2:],
"suffix1": lambda crf_token: crf_token.text[-1:],
"bias": lambda crf_token: "bias",
"pos": lambda crf_token: crf_token.pos_tag,
"pos2": lambda crf_token: crf_token.pos_tag[:2]
if crf_token.pos_tag is not None
else None,
"upper": lambda crf_token: crf_token.text.isupper(),
"digit": lambda crf_token: crf_token.text.isdigit(),
"pattern": lambda crf_token: crf_token.pattern,
"text_dense_features": lambda crf_token: crf_token.dense_features,
"entity": lambda crf_token: crf_token.entity_tag,
}
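# What these feature callables produce for a concrete token (plain-Python illustration):
text = "Berlin"
print(text.lower())    # low     -> 'berlin'
print(text.istitle())  # title   -> True
print(text[:2])        # prefix2 -> 'Be'
print(text[-3:])       # suffix3 -> 'lin'
print(text.isupper())  # upper   -> False
print(text.isdigit())  # digit   -> False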
def __init__(
self,
component_config: Optional[Dict[Text, Any]] = None,
entity_taggers: Optional[Dict[Text, "CRF"]] = None,
) -> None:
super().__init__(component_config)
self.entity_taggers = entity_taggers
self.crf_order = [
ENTITY_ATTRIBUTE_TYPE,
ENTITY_ATTRIBUTE_ROLE,
ENTITY_ATTRIBUTE_GROUP,
]
self._validate_configuration()
self.split_entities_config = self.init_split_entities()
def _validate_configuration(self) -> None:
if len(self.component_config.get("features", [])) % 2 != 1:
raise ValueError(
"Need an odd number of crf feature lists to have a center word."
)
@classmethod
def required_packages(cls) -> List[Text]:
return ["sklearn_crfsuite", "sklearn"]
def train(
self,
training_data: TrainingData,
config: Optional[RasaNLUModelConfig] = None,
**kwargs: Any,
) -> None:
# checks whether there is at least one
# example with an entity annotation
if not training_data.entity_examples:
logger.debug(
"No training examples with entities present. Skipping training "
"of 'CRFEntityExtractor'."
)
return
self.check_correct_entity_annotations(training_data)  # validation: e.g. if jieba tokenization runs before the CRF, check that the jieba token boundaries are consistent with the annotated entity boundaries
if self.component_config[BILOU_FLAG]:
bilou_utils.apply_bilou_schema(training_data)  # build the BILOU tags and store them on the Message objects
# only keep the CRFs for tags we actually have training data for
self._update_crf_order(training_data)
# filter out pre-trained entity examples
entity_examples = self.filter_trainable_entities(training_data.nlu_examples)
dataset = [self._convert_to_crf_tokens(example) for example in entity_examples]  # a list of lists: [[CRFToken, ...], [CRFToken, ...]]
self._train_model(dataset)
def _update_crf_order(self, training_data: TrainingData):
"""Train only CRFs we actually have training data for. 仅训练我们实际拥有训练数据的 CRF"""
_crf_order = []
for tag_name in self.crf_order: #['entity', 'role', 'group']
if tag_name == ENTITY_ATTRIBUTE_TYPE and training_data.entities:
_crf_order.append(ENTITY_ATTRIBUTE_TYPE)
elif tag_name == ENTITY_ATTRIBUTE_ROLE and training_data.entity_roles:
_crf_order.append(ENTITY_ATTRIBUTE_ROLE)
elif tag_name == ENTITY_ATTRIBUTE_GROUP and training_data.entity_groups:
_crf_order.append(ENTITY_ATTRIBUTE_GROUP)
self.crf_order = _crf_order
def process(self, message: Message, **kwargs: Any) -> None:
entities = self.extract_entities(message)
entities = self.add_extractor_name(entities)
message.set(ENTITIES, message.get(ENTITIES, []) + entities, add_to_output=True)
def extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
"""Extract entities from the given message using the trained model(s)."""
if self.entity_taggers is None:
return []
tokens = message.get(TOKENS_NAMES[TEXT])
crf_tokens = self._convert_to_crf_tokens(message)
predictions = {}
for tag_name, entity_tagger in self.entity_taggers.items():
# use predicted entity tags as features for second level CRFs
include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
if include_tag_features:
self._add_tag_to_crf_token(crf_tokens, predictions)
# [{'BOS': True, '0:low': '北京', '0:bias': 'bias', '0:prefix5': '北京', '0:prefix2': '北京', '0:suffix5': '北京', '0:suffix3': '北京', '0:suffix2': '北京', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '的', '1:title': False, '1:upper': False}]
features = self._crf_tokens_to_features(crf_tokens, include_tag_features)
predictions[tag_name] = entity_tagger.predict_marginals_single(features)
"""
Example CRF prediction (marginals per token): {'entity': [{'U-date': 0.0018662581715369866, 'U-location': 0.9534361235771649, 'O': 0.04310048830917551, 'B-location': 0.00040871033087749956, 'I-location': 0.0003764602861231511, 'L-location': 0.000811959325121927},
{'U-date': 0.005369950688885148, 'U-location': 0.004552219783974233, 'O': 0.9834001333595512, 'B-location': 0.0020663145611778053, 'I-location': 0.002047266321012076, 'L-location': 0.002564115285399473},
{'U-date': 0.005864643891460612, 'U-location': 0.005886398421449579, 'O': 0.9790798152952706, 'B-location': 0.0032587770383554573, 'I-location': 0.0023428997898938143, 'L-location': 0.00356746556357009},
{'U-date': 0.013875775748399285, 'U-location': 0.0142575192728364, 'O': 0.9510489275011496, 'B-location': 0.007400299042951769, 'I-location': 0.004890205794937807, 'L-location': 0.008527272639725128},
{'U-date': 0.00491783928007722, 'U-location': 0.004268236726299186, 'O': 0.9804322682816534, 'B-location': 0.0036810716731708877, 'I-location': 0.0036080863158401088, 'L-location': 0.0030924977229592533}]}
"""
# convert predictions into a list of tags and a list of confidences
tags, confidences = self._tag_confidences(tokens, predictions)
return self.convert_predictions_into_entities(
message.get(TEXT), tokens, tags, self.split_entities_config, confidences
)
def _add_tag_to_crf_token(
self,
crf_tokens: List[CRFToken],
predictions: Dict[Text, List[Dict[Text, float]]],
):
"""Add predicted entity tags to CRF tokens."""
if ENTITY_ATTRIBUTE_TYPE in predictions:
_tags, _ = self._most_likely_tag(predictions[ENTITY_ATTRIBUTE_TYPE])
for tag, token in zip(_tags, crf_tokens):
token.entity_tag = tag
def _most_likely_tag(
self, predictions: List[Dict[Text, float]]
) -> Tuple[List[Text], List[float]]:
"""获取置信度最高的实体标签。 """
_tags = []
_confidences = []
for token_predictions in predictions:
tag = max(token_predictions, key=lambda key: token_predictions[key])
_tags.append(tag)
if self.component_config[BILOU_FLAG]:
# if the BILOU flag is set, sum up the probabilities of the B-, I-, L- and U- tags of the entity
_confidences.append(
sum(
_confidence
for _tag, _confidence in token_predictions.items()
if bilou_utils.tag_without_prefix(tag)
== bilou_utils.tag_without_prefix(_tag)
)
)
else:
_confidences.append(token_predictions[tag])
return _tags, _confidences
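# A worked example of the BILOU confidence summation (simplified sketch with a
# stand-in for bilou_utils.tag_without_prefix):
def tag_without_prefix(tag):
    return tag.split("-", 1)[1] if "-" in tag else tag

token_predictions = {"U-location": 0.50, "B-location": 0.20, "O": 0.30}
tag = max(token_predictions, key=lambda key: token_predictions[key])  # 'U-location'
confidence = sum(
    conf for t, conf in token_predictions.items()
    if tag_without_prefix(t) == tag_without_prefix(tag)
)
print(tag, confidence)  # U-location 0.7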
def _tag_confidences(
self, tokens: List[Token], predictions: Dict[Text, List[Dict[Text, float]]]
) -> Tuple[Dict[Text, List[Text]], Dict[Text, List[float]]]:
"""使用标记的置信度值获取最有可能的标签预测。"""
tags = {}
confidences = {}
for tag_name, predicted_tags in predictions.items():
if len(tokens) != len(predicted_tags):
raise Exception(
"Inconsistency in amount of tokens between crfsuite and message"
)
_tags, _confidences = self._most_likely_tag(predicted_tags)
if self.component_config[BILOU_FLAG]:
_tags, _confidences = bilou_utils.ensure_consistent_bilou_tagging(
_tags, _confidences
)
confidences[tag_name] = _confidences
tags[tag_name] = _tags
return tags, confidences
@classmethod
def load(
cls,
meta: Dict[Text, Any],
model_dir: Text = None,
model_metadata: Metadata = None,
cached_component: Optional["CRFEntityExtractor"] = None,
**kwargs: Any,
) -> "CRFEntityExtractor":
import joblib
file_names = meta.get("files")
entity_taggers = {}
if not file_names:
logger.debug(
f"Failed to load model for 'CRFEntityExtractor'. "
f"Maybe you did not provide enough training data and no model was "
f"trained or the path '{os.path.abspath(model_dir)}' doesn't exist?"
)
return cls(component_config=meta)
for name, file_name in file_names.items():
model_file = os.path.join(model_dir, file_name)
if os.path.exists(model_file):
entity_taggers[name] = joblib.load(model_file)
else:
logger.debug(
f"Failed to load model for tag '{name}' for 'CRFEntityExtractor'. "
f"Maybe you did not provide enough training data and no model was "
f"trained or the path '{os.path.abspath(model_file)}' doesn't "
f"exist?"
)
return cls(meta, entity_taggers)
def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
"""将此模型保存到传递的目录中。"""
import joblib
file_names = {}
if self.entity_taggers:
for name, entity_tagger in self.entity_taggers.items():
file_name = f"{file_name}.{name}.pkl"
model_file_name = os.path.join(model_dir, file_name)
joblib.dump(entity_tagger, model_file_name)
file_names[name] = file_name
return {"files": file_names}
def _crf_tokens_to_features(
self, crf_tokens: List[CRFToken], include_tag_features: bool = False
) -> List[Dict[Text, Any]]:
"""将标记列表转换为离散特征"""
configured_features = self.component_config["features"]
sentence_features = []
for token_idx in range(len(crf_tokens)):
# the features for the current token include features of the tokens
# before and after the current token (if defined in the config):
# token before (-1), current token (0), token after (+1)
window_size = len(configured_features)
half_window_size = window_size // 2
window_range = range(-half_window_size, half_window_size + 1)
token_features = self._create_features_for_token(
crf_tokens,
token_idx,
half_window_size,
window_range,
include_tag_features,
)  # builds the training features, e.g. {'BOS': True, '0:low': '今天', '0:bias': 'bias', '0:prefix5': '今天', '0:prefix2': '今天', '0:suffix5': '今天', '0:suffix3': '今天', '0:suffix2': '今天', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '上海', '1:title': False, '1:upper': False}
sentence_features.append(token_features)
return sentence_features
def _create_features_for_token(
self,
crf_tokens: List[CRFToken],
token_idx: int,
half_window_size: int,
window_range: range,
include_tag_features: bool,
):
"""将标记转换为离散特征,包括词前和词后。"""
configured_features = self.component_config["features"] # [['low', 'title', 'upper'], ['low', 'bias', 'prefix5', 'prefix2', 'suffix5', 'suffix3', 'suffix2', 'upper', 'title', 'digit', 'pattern'], ['low', 'title', 'upper']]
prefixes = [str(i) for i in window_range]  # e.g. ['-1', '0', '1']
token_features = {}
# iterate over the tokens in the window range (-1, 0, +1) to collect the features for the token at token_idx
for pointer_position in window_range:
current_token_idx = token_idx + pointer_position
if current_token_idx >= len(crf_tokens):
# the token is at the end of the sentence
token_features["EOS"] = True
elif current_token_idx < 0:
# the token is at the beginning of the sentence
token_features["BOS"] = True
else:
token = crf_tokens[current_token_idx]
# get the features to extract for the token we are currently looking at
current_feature_idx = pointer_position + half_window_size
features = configured_features[current_feature_idx]
prefix = prefixes[current_feature_idx]
# we add the 'entity' feature to include the entity type as a feature for the role and group CRFs (do not modify the 'features' list itself, otherwise we would append 'entity' over and over again, which makes training very slow)
additional_features = []
if include_tag_features:
additional_features.append("entity")
for feature in features + additional_features:
if feature == "pattern":
# add all regexes extracted by the 'RegexFeaturizer' as features: 'pattern_name' is the name of the pattern the user set in the training data, 'matched' is 'True' or 'False' depending on whether the token actually matched the pattern or not
regex_patterns = self.function_dict[feature](token)
for pattern_name, matched in regex_patterns.items():
token_features[
f"{prefix}:{feature}:{pattern_name}"
] = matched
else:
value = self.function_dict[feature](token)
token_features[f"{prefix}:{feature}"] = value
return token_features
@staticmethod
def _crf_tokens_to_tags(crf_tokens: List[CRFToken], tag_name: Text) -> List[Text]:
"""Return the list of tags for the given tag name. 返回给定标签名称的标签列表"""
if tag_name == ENTITY_ATTRIBUTE_ROLE:
return [crf_token.entity_role_tag for crf_token in crf_tokens]
if tag_name == ENTITY_ATTRIBUTE_GROUP:
return [crf_token.entity_group_tag for crf_token in crf_tokens]
return [crf_token.entity_tag for crf_token in crf_tokens]
@staticmethod
def _pattern_of_token(message: Message, idx: int) -> Dict[Text, bool]:
"""获取 'RegexFeaturizer 提取的给定索引处的令牌模式
“RegexFeaturizer”将训练数据中列出的所有pattern 添加到令牌中。 pattern名称映射到“true”(pattern 适用于令牌)或“false”(pattern 不适用于令牌)。
"""
if message.get(TOKENS_NAMES[TEXT]) is not None:
return message.get(TOKENS_NAMES[TEXT])[idx].get("pattern", {})
return {}
def _get_dense_features(self, message: Message) -> Optional[List]:
"""将密集特征转换为 python-crfsuite 特征格式。"""
features, _ = message.get_dense_features(
TEXT, self.component_config["featurizers"]
)  # the values from the message's self.features
if features is None:
return None
tokens = message.get(TOKENS_NAMES[TEXT])
if len(tokens) != len(features.features):
rasa.shared.utils.io.raise_warning(
f"Number of dense features ({len(features.features)}) for attribute "
f"'TEXT' does not match number of tokens ({len(tokens)}).",
docs=DOCS_URL_COMPONENTS + "#crfentityextractor",
)
return None
# convert to python-crfsuite feature format
features_out = []
for feature in features.features:
feature_dict = {
str(index): token_features
for index, token_features in enumerate(feature)
}
converted = {"text_dense_features": feature_dict}
features_out.append(converted)
return features_out
def _convert_to_crf_tokens(self, message: Message) -> List[CRFToken]:
"""获取消息并将其转换为 crfsuite 格式。"""
# 将消息转换成crfsuite格式
crf_format = []
tokens = message.get(TOKENS_NAMES[TEXT])
text_dense_features = self._get_dense_features(message)  # convert dense features into the python-crfsuite format
tags = self._get_tags(message)  # e.g. {'entity': ['U-date', 'U-location', 'O', 'O', 'O']}
for i, token in enumerate(tokens):
pattern = self._pattern_of_token(message, i)
entity = self.get_tag_for(tags, ENTITY_ATTRIBUTE_TYPE, i)
group = self.get_tag_for(tags, ENTITY_ATTRIBUTE_GROUP, i)
role = self.get_tag_for(tags, ENTITY_ATTRIBUTE_ROLE, i)
pos_tag = token.get(POS_TAG_KEY)
dense_features = (
text_dense_features[i] if text_dense_features is not None else []
)
crf_format.append(
CRFToken(
text=token.text,
pos_tag=pos_tag,
entity_tag=entity,
entity_group_tag=group,
entity_role_tag=role,
pattern=pattern,
dense_features=dense_features,
)
)
return crf_format # [CRFToken(),...]
def _get_tags(self, message: Message) -> Dict[Text, List[Text]]:
"""获取分配的消息实体标签。"""
tokens = message.get(TOKENS_NAMES[TEXT])
tags = {}
for tag_name in self.crf_order:
if self.component_config[BILOU_FLAG]:
bilou_key = bilou_utils.get_bilou_key_for_tag(tag_name)
if message.get(bilou_key):
_tags = message.get(bilou_key)
else:
_tags = [NO_ENTITY_TAG for _ in tokens]
else:
_tags = [
determine_token_labels(
token, message.get(ENTITIES), attribute_key=tag_name
)
for token in tokens
]
tags[tag_name] = _tags
return tags
def _train_model(self, df_train: List[List[CRFToken]]) -> None:
"""Train the crf tagger based on the training data."""
import sklearn_crfsuite
self.entity_taggers = {}
for tag_name in self.crf_order:
logger.debug(f"Training CRF for '{tag_name}'.")
# add entity tag features for second level CRFs
include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
X_train = [
self._crf_tokens_to_features(sentence, include_tag_features)
for sentence in df_train
]  # builds the CRF training features, e.g. [[{'BOS': True, '0:low': '今天', '0:bias': 'bias', '0:prefix5': '今天', '0:prefix2': '今天', '0:suffix5': '今天', '0:suffix3': '今天', '0:suffix2': '今天', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '上海', '1:title': False, '1:upper': False}, {'-1:low': '今天', '-1:title': False, '-1:upper': False, '0:low': '上海', '0:bias': 'bias', '0:prefix5': '上海', '0:prefix2': '上海', '0:suffix5': '上海', '0:suffix3': '上海', '0:suffix2': '上海', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '的', '1:title': False, '1:upper': False}, {'-1:low': '上海', '-1:title': False, '-1:upper': False, '0:low': '的', '0:bias': 'bias', '0:prefix5': '的', '0:prefix2': '的', '0:suffix5': '的', '0:suffix3': '的', '0:suffix2': '的', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '天气', '1:title': False, '1:upper': False}, {'-1:low': '的', '-1:title': False, '-1:upper': False, '0:low': '天气', '0:bias': 'bias', '0:prefix5': '天气', '0:prefix2': '天气', '0:suffix5': '天气', '0:suffix3': '天气', '0:suffix2': '天气', '0:upper': False, '0:title': False, '0:digit': False, '1:low': '怎么样', '1:title': False, '1:upper': False}, {'-1:low': '天气', '-1:title': False, '-1:upper': False, '0:low': '怎么样', '0:bias': 'bias', '0:prefix5': '怎么样', '0:prefix2': '怎么', '0:suffix5': '怎么样', '0:suffix3': '怎么样', '0:suffix2': '么样', '0:upper': False, '0:title': False, '0:digit': False, 'EOS': True}]]
y_train = [
self._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
] #[['U-date', 'U-location', 'O', 'O', 'O']]
entity_tagger = sklearn_crfsuite.CRF(
algorithm="lbfgs",
# coefficient for L1 penalty
c1=self.component_config["L1_c"],
# coefficient for L2 penalty
c2=self.component_config["L2_c"],
# stop earlier
max_iterations=self.component_config["max_iterations"],
# include transitions that are possible, but not observed
all_possible_transitions=True,
)
entity_tagger.fit(X_train, y_train)
self.entity_taggers[tag_name] = entity_tagger
logger.debug("Training finished.")
rasa/nlu/utils/bilou_utils.py
Tagging tokens with the BILOU annotation scheme
def apply_bilou_schema_to_message(message: "Message") -> None:
"""获取 BILOU 实体标签列表并将它们设置在给定的消息上。"""
entities = message.get(ENTITIES)
if not entities:
return
tokens = message.get(TOKENS_NAMES[TEXT])
for attribute, message_key in [
(ENTITY_ATTRIBUTE_TYPE, BILOU_ENTITIES), # entity bilou_entities
(ENTITY_ATTRIBUTE_ROLE, BILOU_ENTITIES_ROLE), # role bilou_entities_role
(ENTITY_ATTRIBUTE_GROUP, BILOU_ENTITIES_GROUP), # group bilou_entities_group
]:
entities = map_message_entities(message, attribute)  # e.g. [(0, 2, 'date'), (2, 4, 'location')] for the 'entity' attribute
output = bilou_tags_from_offsets(tokens, entities)  # e.g. entity: ['U-date', 'U-location', 'O', 'O', 'O'], role: ['O', 'O', 'O', 'O', 'O'], group: ['O', 'O', 'O', 'O', 'O']
message.set(message_key, output)
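# A simplified sketch of the offsets -> BILOU conversion for the running example
# "今天上海的天气怎么样" (not the Rasa helper itself):
def bilou_tags_from_offsets_sketch(tokens, entities):
    tags = ["O"] * len(tokens)
    for start, end, label in entities:
        idxs = [i for i, (s, e) in enumerate(tokens) if s >= start and e <= end]
        if len(idxs) == 1:
            tags[idxs[0]] = f"U-{label}"
        elif idxs:
            tags[idxs[0]] = f"B-{label}"
            tags[idxs[-1]] = f"L-{label}"
            for i in idxs[1:-1]:
                tags[i] = f"I-{label}"
    return tags

tokens = [(0, 2), (2, 4), (4, 5), (5, 7), (7, 10)]      # 今天/上海/的/天气/怎么样
entities = [(0, 2, "date"), (2, 4, "location")]
print(bilou_tags_from_offsets_sketch(tokens, entities))
# ['U-date', 'U-location', 'O', 'O', 'O']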