Byte Pair Encoding (BPE) 算法的核心实现主要涉及以下几个步骤:
- 统计字符对频率:
- 遍历训练语料库中的所有词或句子,统计相邻字符对的出现频率。这通常是通过滑动窗口在文本数据上移动,每次移动一个字符,然后记录相邻字符对的出现次数。
def get_pairs(self, word): # 用于存储字符对的列表 pairs = [] # 初始化前一个字符为单词的第一个字符 prev_char = word[0] # 遍历单词中的剩余字符 for char in word[1:]: # 将前一个字符和当前字符组成的字符对添加到列表中 pairs.append((prev_char, char)) # 更新前一个字符为当前字符 prev_char = char # 返回字符对列表 return pairs pairs = defaultdict(int) for word in text.split(): # 获取单词中的所有字符对 word_pairs = self.get_pairs(word) # 统计每个字符对的出现次数 for pair in word_pairs: pairs[pair] += 1
- 遍历训练语料库中的所有词或句子,统计相邻字符对的出现频率。这通常是通过滑动窗口在文本数据上移动,每次移动一个字符,然后记录相邻字符对的出现次数。
- 选择最频繁字符对:
- 根据统计的频率,选择出现次数最多的字符对。在BPE算法的每一步中,都会选择当前出现频率最高的字符对进行合并。
- 合并字符对:
- 将选出的最频繁字符对合并成一个新的“字符”(通常是一个新的标记或符号)。这个新字符代表了一个常见的字符组合,可以被视为一个新的“词素”。
- 更新词汇和频率:
- 在语料库中,将所有选定的字符对替换为新合并的字符,并更新字符频率统计。这意味着在后续的迭代中,新合并的字符将作为一个整体被考虑。
- 迭代过程:
- 重复步骤2-4,直到达到预定的词汇表大小,或者没有更多的字符对可以合并(即所有字符对的出现次数都为1)。
for _ in range(num_merges): # 如果没有字符对可以合并,则退出循环 if not pairs: break # 找到出现频率最高的字符对 most_common = max(pairs, key=pairs.get) # 将该字符对添加到合并列表中 self.merges.append(most_common) # 将该字符对的频率重置为0,表示已经合并过 pairs[most_common] = 0 # 遍历文本中的每个单词,进行合并操作 for word in text.split(): word_pairs = self.get_pairs(word) # 如果单词中包含需要合并的字符对,则进行合并 if most_common in word_pairs: new_word = re.sub(f'(.)({most_common[1]})', lambda m: m.group(1) + most_common[0] + most_common[1] if m.group(1) + m.group(2) == most_common else m.group(0), word) # 用合并后的新单词替换原文本中的单词 text = text.replace(word, new_word) # 更新词汇频率统计 self.vocab[most_common[0] + most_common[1]] = self.vocab.get(most_common[0] + most_common[1], 0) + 1
- 重复步骤2-4,直到达到预定的词汇表大小,或者没有更多的字符对可以合并(即所有字符对的出现次数都为1)。
- 编码和解码:
- 编码时,将输入的文本根据最终生成的词汇表进行分词(即将文本中的字符或字符组合替换为对应的BPE标记)。
- 解码时,将BPE标记还原为原始的字符序列。
def encode(self, text): # 如果文本已经在缓存中,则直接返回缓存中的编码结果 if text in self.cache: return self.cache[text] # 将文本拆分为单词列表 tokens = text.split() # 用于存储编码后的单词列表 encoded_tokens = [] # 遍历每个单词进行编码 for token in tokens: encoded = [] # 当单词长度大于1时,继续进行编码 while len(token) > 1: pair = None # 遍历合并列表,从后往前查找需要合并的字符对 for merge in reversed(self.merges): if merge[0] + merge[1] in token: pair = merge break # 如果找到了需要合并的字符对,则进行合并操作,并添加编码标记'@@' if pair: token = token.replace(pair[0] + pair[1], pair[0] + '@@' + pair[1]) else: # 如果没有找到需要合并的字符对,则将单词的第一个字符添加到编码列表中,并删除该字符 encoded.append(token[0]) token = token[1:] # 将剩余的字符添加到编码列表中 encoded.extend(token) # 将编码后的单词添加到编码单词列表中,单词内部的字符之间用空格隔开 encoded_tokens.append(' '.join(encoded)) # 将编码后的单词列表合并为一个字符串,单词之间用空格隔开 encoded_text = ' '.join(encoded_tokens) # 将编码结果添加到缓存中 self.cache[text] = encoded_text # 返回编码结果 return encoded_text def decode(self, encoded_text): # 将编码文本拆分为单词列表 tokens = encoded_text.split() # 用于存储解码后的单词列表 decoded_tokens = [] # 遍历每个编码后的单词进行解码 for token in tokens: decoded = [] # 将单词拆分为由'@@'分隔的子串列表 subtokens = token.split('@@') # 遍历每个子串进行解码 for subtoken in subtokens: # 如果子串以'@@'结尾,则去除'@@'并将其添加到解码列表中 if subtoken.endswith('@@'): decoded.append(subtoken[:-2]) else: # 否则直接将子串添加到解码列表中 decoded.append(subtoken) # 将解码后的字符列表合并为一个字符串 decoded_str = ''.join(decoded) # 将解码后的单词添加到解码单词列表中 decoded_tokens.append(decoded_str) # 将解码后的单词列表合并为一个字符串,单词之间用空格隔开 return ' '.join(decoded_tokens)
BPE算法的关键在于它通过迭代地合并最频繁的字符对来构建一个新的、更有效的词汇表。这种方法能够有效地处理不在初始词汇表中的词(OOV词),因为它可以动态地创建新的词汇单元来表示这些未知词。同时,BPE还可以帮助减少词汇表的大小,提高模型的泛化能力,并有助于处理形态丰富的语言。