Byte Pair Encoding (BPE) 算法的核心实现

Byte Pair Encoding (BPE) 算法的核心实现主要涉及以下几个步骤:

  • 统计字符对频率
    • 遍历训练语料库中的所有词或句子,统计相邻字符对的出现频率。这通常是通过滑动窗口在文本数据上移动,每次移动一个字符,然后记录相邻字符对的出现次数。
      def get_pairs(self, word):  
          # 用于存储字符对的列表  
          pairs = []  
          # 初始化前一个字符为单词的第一个字符  
          prev_char = word[0]  
          # 遍历单词中的剩余字符  
          for char in word[1:]:  
              # 将前一个字符和当前字符组成的字符对添加到列表中  
              pairs.append((prev_char, char))  
              # 更新前一个字符为当前字符  
              prev_char = char  
          # 返回字符对列表  
          return pairs  
      
      
      pairs = defaultdict(int)  
      for word in text.split():  
          # 获取单词中的所有字符对  
          word_pairs = self.get_pairs(word)  
          # 统计每个字符对的出现次数  
          for pair in word_pairs:  
              pairs[pair] += 1  
      
  • 选择最频繁字符对
    • 根据统计的频率,选择出现次数最多的字符对。在BPE算法的每一步中,都会选择当前出现频率最高的字符对进行合并。
  • 合并字符对
    • 将选出的最频繁字符对合并成一个新的“字符”(通常是一个新的标记或符号)。这个新字符代表了一个常见的字符组合,可以被视为一个新的“词素”。
  • 更新词汇和频率
    • 在语料库中,将所有选定的字符对替换为新合并的字符,并更新字符频率统计。这意味着在后续的迭代中,新合并的字符将作为一个整体被考虑。
  • 迭代过程
    • 重复步骤2-4,直到达到预定的词汇表大小,或者没有更多的字符对可以合并(即所有字符对的出现次数都为1)。
      for _ in range(num_merges):  
          # 如果没有字符对可以合并,则退出循环  
          if not pairs:  
              break  
          # 找到出现频率最高的字符对  
          most_common = max(pairs, key=pairs.get)  
          # 将该字符对添加到合并列表中  
          self.merges.append(most_common)  
          # 将该字符对的频率重置为0,表示已经合并过  
          pairs[most_common] = 0  
      
          # 遍历文本中的每个单词,进行合并操作  
          for word in text.split():  
              word_pairs = self.get_pairs(word)  
              # 如果单词中包含需要合并的字符对,则进行合并  
              if most_common in word_pairs:  
                  new_word = re.sub(f'(.)({most_common[1]})',   
                      lambda m: m.group(1) + most_common[0] + most_common[1]   
                      if m.group(1) + m.group(2) == most_common else m.group(0),   
                      word)  
                  # 用合并后的新单词替换原文本中的单词  
                  text = text.replace(word, new_word)  
                  # 更新词汇频率统计  
                  self.vocab[most_common[0] + most_common[1]] = self.vocab.get(most_common[0] + most_common[1], 0) + 1  
      
  • 编码和解码
    • 编码时,将输入的文本根据最终生成的词汇表进行分词(即将文本中的字符或字符组合替换为对应的BPE标记)。
    • 解码时,将BPE标记还原为原始的字符序列。
      def encode(self, text):  
          # 如果文本已经在缓存中,则直接返回缓存中的编码结果  
          if text in self.cache:  
              return self.cache[text]  
      
          # 将文本拆分为单词列表  
          tokens = text.split()  
          # 用于存储编码后的单词列表  
          encoded_tokens = []  
          # 遍历每个单词进行编码  
          for token in tokens:  
              encoded = []  
              # 当单词长度大于1时,继续进行编码  
              while len(token) > 1:  
                  pair = None  
                  # 遍历合并列表,从后往前查找需要合并的字符对  
                  for merge in reversed(self.merges):  
                      if merge[0] + merge[1] in token:  
                          pair = merge  
                          break  
                  # 如果找到了需要合并的字符对,则进行合并操作,并添加编码标记'@@'  
                  if pair:  
                      token = token.replace(pair[0] + pair[1], pair[0] + '@@' + pair[1])  
                  else:  
                      # 如果没有找到需要合并的字符对,则将单词的第一个字符添加到编码列表中,并删除该字符  
                      encoded.append(token[0])  
                      token = token[1:]  
              # 将剩余的字符添加到编码列表中  
              encoded.extend(token)  
              # 将编码后的单词添加到编码单词列表中,单词内部的字符之间用空格隔开  
              encoded_tokens.append(' '.join(encoded))  
      
          # 将编码后的单词列表合并为一个字符串,单词之间用空格隔开  
          encoded_text = ' '.join(encoded_tokens)  
          # 将编码结果添加到缓存中  
          self.cache[text] = encoded_text  
          # 返回编码结果  
          return encoded_text  
      
      def decode(self, encoded_text):  
          # 将编码文本拆分为单词列表  
          tokens = encoded_text.split()  
          # 用于存储解码后的单词列表  
          decoded_tokens = []  
          # 遍历每个编码后的单词进行解码  
          for token in tokens:  
              decoded = []  
              # 将单词拆分为由'@@'分隔的子串列表  
              subtokens = token.split('@@')  
              # 遍历每个子串进行解码  
              for subtoken in subtokens:  
                  # 如果子串以'@@'结尾,则去除'@@'并将其添加到解码列表中  
                  if subtoken.endswith('@@'):  
                      decoded.append(subtoken[:-2])  
                  else:  
                      # 否则直接将子串添加到解码列表中  
                      decoded.append(subtoken)  
              # 将解码后的字符列表合并为一个字符串  
              decoded_str = ''.join(decoded)  
              # 将解码后的单词添加到解码单词列表中  
              decoded_tokens.append(decoded_str)  
          # 将解码后的单词列表合并为一个字符串,单词之间用空格隔开  
          return ' '.join(decoded_tokens)  
      

BPE算法的关键在于它通过迭代地合并最频繁的字符对来构建一个新的、更有效的词汇表。这种方法能够有效地处理不在初始词汇表中的词(OOV词),因为它可以动态地创建新的词汇单元来表示这些未知词。同时,BPE还可以帮助减少词汇表的大小,提高模型的泛化能力,并有助于处理形态丰富的语言。

  • 32
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值