Environment: Python 3.7.9
Rasa 2.0.6
Rasa SDK 2.0.0
1. The Problem
While building a Chinese Q&A bot with Rasa, I ran into a problem: I added a form whose slot is filled via from_entity, and added RegexEntityExtractor to the pipeline. Say the entity is city: nlu.yml contains only 郑州 (Zhengzhou) as a training example, plus a lookup table for city (see figure). In actual conversations, however, only 郑州 gets recognized, by DIETClassifier; lookup-table entries that do not appear in the training data are never auto-filled (see figure).
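For context, the relevant parts of the training data looked roughly like this (a minimal sketch; the intent name and the extra cities are illustrative placeholders, not my actual project files):

```yaml
# nlu.yml (sketch)
nlu:
- intent: ask_weather
  examples: |
    - 查一下[郑州](city)的天气
- lookup: city
  examples: |
    - 郑州
    - 洛阳
    - 开封
```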
2. Analysis
Why does RegexEntityExtractor fail to recognize the lookup-table entries? The official docs offered no answer, so I went straight to the source code of RegexEntityExtractor.
```python
# regex_entity_extractor.py
import rasa.nlu.utils.pattern_utils as pattern_utils
...

class RegexEntityExtractor(EntityExtractor):
    """Searches for entities in the user's message using the lookup tables and regexes
    defined in the training data."""
    ...

    def train(
        self,
        training_data: TrainingData,
        config: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        self.patterns = pattern_utils.extract_patterns(
            training_data,
            use_lookup_tables=self.component_config["use_lookup_tables"],
            use_regexes=self.component_config["use_regexes"],
            use_only_entities=True,
        )

        if not self.patterns:
            rasa.shared.utils.io.raise_warning(
                "No lookup tables or regexes defined in the training data that have "
                "a name equal to any entity in the training data. In order for this "
                "component to work you need to define valid lookup tables or regexes "
                "in the training data."
            )
    ...

    def _extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
        """Extract entities of the given type from the given user message."""
        ...
        for pattern in self.patterns:
            matches = re.finditer(pattern["pattern"], message.get(TEXT), flags=flags)
            matches = list(matches)

            for match in matches:
                start_index = match.start()
                end_index = match.end()
                entities.append(
                    {
                        ENTITY_ATTRIBUTE_TYPE: pattern["name"],
                        ENTITY_ATTRIBUTE_START: start_index,
                        ENTITY_ATTRIBUTE_END: end_index,
                        ENTITY_ATTRIBUTE_VALUE: message.get(TEXT)[
                            start_index:end_index
                        ],
                    }
                )

        return entities
    ...
```
Only the two most important methods are shown here.
In _extract_entities, RegexEntityExtractor matches the user input against the regular expressions in self.patterns; in train, we can see that self.patterns is obtained by calling extract_patterns from pattern_utils. So let's keep digging.
```python
# pattern_utils.py
def _convert_lookup_tables_to_regex(
    training_data: TrainingData, use_only_entities: bool = False
) -> List[Dict[Text, Text]]:
    """Convert the lookup tables from the training data to regex patterns.

    Args:
        training_data: The training data.
        use_only_entities: If True only regex features with a name equal to a entity
          are considered.

    Returns:
        A list of regex patterns.
    """
    patterns = []
    for table in training_data.lookup_tables:
        if use_only_entities and table["name"] not in training_data.entities:
            continue
        regex_pattern = _generate_lookup_regex(table)
        lookup_regex = {"name": table["name"], "pattern": regex_pattern}
        patterns.append(lookup_regex)
    return patterns


def _generate_lookup_regex(lookup_table: Dict[Text, Union[Text, List[Text]]]) -> Text:
    """Creates a regex pattern from the given lookup table.

    The lookup table is either a file or a list of entries.

    Args:
        lookup_table: The lookup table.

    Returns:
        The regex pattern.
    """
    lookup_elements = lookup_table["elements"]

    # if it's a list, it should be the elements directly
    if isinstance(lookup_elements, list):
        elements_to_regex = lookup_elements
    # otherwise it's a file path.
    else:
        elements_to_regex = read_lookup_table_file(lookup_elements)

    # sanitize the regex, escape special characters
    elements_sanitized = [re.escape(e) for e in elements_to_regex]
    # regex matching elements with word boundaries on either side
    return "(\\b" + "\\b|\\b".join(elements_sanitized) + "\\b)"
...

def extract_patterns(
    training_data: TrainingData,
    use_lookup_tables: bool = True,
    use_regexes: bool = True,
    use_only_entities: bool = False,
) -> List[Dict[Text, Text]]:
    """Extract a list of patterns from the training data.

    The patterns are constructed using the regex features and lookup tables defined
    in the training data.

    Args:
        training_data: The training data.
        use_only_entities: If True only lookup tables and regex features with a name
          equal to a entity are considered.
        use_regexes: Boolean indicating whether to use regex features or not.
        use_lookup_tables: Boolean indicating whether to use lookup tables or not.

    Returns:
        The list of regex patterns.
    """
    if not training_data.lookup_tables and not training_data.regex_features:
        return []

    patterns = []
    if use_regexes:
        patterns.extend(_collect_regex_features(training_data, use_only_entities))
    if use_lookup_tables:
        patterns.extend(
            _convert_lookup_tables_to_regex(training_data, use_only_entities)
        )

    return patterns
```
Only the three most important functions are shown here.
In extract_patterns, the component checks whether the use_lookup_tables option is enabled; if so, it calls _convert_lookup_tables_to_regex, which converts each lookup table into a regex. The official documentation confirms that lookup tables are matched via regular expressions:
> Lookup tables are lists of words used to generate case-insensitive regular expression patterns.
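To make the conversion concrete, here is the pattern that gets generated for a small city table (a standalone reproduction of the return expression from _generate_lookup_regex, run outside of Rasa):

```python
import re

# rebuild the pattern the same way _generate_lookup_regex does:
# escape each entry, join with "|", wrap each entry in \b anchors
elements = ["郑州", "洛阳"]
elements_sanitized = [re.escape(e) for e in elements]
pattern = "(\\b" + "\\b|\\b".join(elements_sanitized) + "\\b)"
print(pattern)  # (\b郑州\b|\b洛阳\b)
```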
Following the trail into _convert_lookup_tables_to_regex, we see that the regex is in turn produced by _generate_lookup_regex. And in _generate_lookup_regex we finally find the truth. Look at the return statement: the pattern does not simply join the lookup entries with |; it also wraps every entry in \b word-boundary anchors, and that \b is the crux of the problem. After some searching (regex is not my strong suit), it turns out that \b matches only at the boundary between a word character and a non-word character, or at the start or end of the string: er\b matches the er in never, but not the er in verb. Chinese characters count as word characters, and written Chinese has no spaces between words, so a lookup entry surrounded by other Chinese characters has no boundary around it and can never match inside a sentence.
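The effect of \b is easy to reproduce with Python's re module directly, without any Rasa involved:

```python
import re

# er\b matches the "er" at the end of "never"...
assert re.search(r"er\b", "never") is not None
# ...but not the "er" inside "verb" (a word character follows, so no boundary)
assert re.search(r"er\b", "verb") is None

# Chinese characters are word characters and Chinese has no spaces,
# so an entry surrounded by other characters has no \b boundary:
assert re.search(r"\b郑州\b", "郑州") is not None        # entry alone: matches
assert re.search(r"\b郑州\b", "我在郑州工作") is None     # entry inside a sentence: no match
assert re.search(r"郑州", "我在郑州工作") is not None     # without \b it matches fine
```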
3. The Fix
With the mystery solved, the fix is simple: delete the \\b anchors from the return value of _generate_lookup_regex in rasa/nlu/utils/pattern_utils.py, and you get a RegexEntityExtractor that works for Chinese.
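As a sketch, the patched return value can be built like this (a standalone function with a hypothetical name that mirrors the change; verify against the source of your own Rasa version before editing the installed package):

```python
import re

def generate_lookup_regex_no_boundaries(elements):
    # escape each lookup entry, then join with "|" -- no \b anchors,
    # so entries also match inside unsegmented Chinese text
    elements_sanitized = [re.escape(e) for e in elements]
    return "(" + "|".join(elements_sanitized) + ")"

pattern = generate_lookup_regex_no_boundaries(["郑州", "洛阳"])
assert re.search(pattern, "我在郑州工作") is not None  # now matches inside a sentence
```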
But this introduces another problem: if DIETClassifier has already recognized a city and RegexEntityExtractor then recognizes the same city again via the lookup table, the entity gets extracted twice (see figure).
So after patching pattern_utils.py, the _extract_entities method in rasa/nlu/extractors/regex_entity_extractor.py also needs the following change, so that the same city is not extracted more than once.
```python
def _extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
    """Extract entities of the given type from the given user message."""
    entities = []

    # entity values already extracted by earlier components
    extracted_entities = [m["value"] for m in message.get(ENTITIES, [])]

    flags = 0  # default flag
    if not self.case_sensitive:
        flags = re.IGNORECASE

    for pattern in self.patterns:
        matches = re.finditer(pattern["pattern"], message.get(TEXT), flags=flags)
        matches = list(matches)

        for match in matches:
            start_index = match.start()
            end_index = match.end()

            # skip this match if the entity was already extracted
            if match.group() in extracted_entities:
                continue

            entities.append(
                {
                    ENTITY_ATTRIBUTE_TYPE: pattern["name"],
                    ENTITY_ATTRIBUTE_START: start_index,
                    ENTITY_ATTRIBUTE_END: end_index,
                    ENTITY_ATTRIBUTE_VALUE: message.get(TEXT)[
                        start_index:end_index
                    ],
                }
            )

    return entities
```
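Stripped of the Rasa plumbing, the deduplication amounts to filtering regex matches against the values already attached to the message (a standalone illustration; the names here are placeholders):

```python
import re

already_extracted = ["郑州"]   # e.g. values DIETClassifier already found
pattern = "(郑州|洛阳)"        # lookup regex with the \b anchors removed
text = "从郑州到洛阳"

# keep only matches whose text was not already extracted upstream
new_values = [
    m.group() for m in re.finditer(pattern, text)
    if m.group() not in already_extracted
]
assert new_values == ["洛阳"]
```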
4. Fixed?
Just as I thought the case was closed, I happened to notice that major cities such as 北京 (Beijing) and 上海 (Shanghai) never had this problem at all: DIETClassifier auto-fills them correctly even when they are not in the lookup table. My guess is that this has to do with the pre-trained language model I'm using; it needs further investigation.