生物大分子平台(9)
2021SC@SDUSC
文章目录
0 本周工作
本周继续学习transformer中的BPE分词方法,并且阅读完成apply_bpe方法。
1 BPE算法简介
参见上篇推文
2 APPLY_BPE代码分析
这部分的代码承接上部分学习BPE的代码,类似于数据挖掘中的序列挖掘得到的结果之后,对新输入的数据使用BPE进行编码。
2.1 导入引用库
这里新接触到的库是inspect库,inspect库也是python中一个常用的库。inspect是用来获取对象的信息,对象包括模块(往往是一个py文件)、类、方法、函数、报错追踪、帧对象和代码对象。例如,它能用来帮助你检验类的内容,检索一个方法的源代码,提取并格式化函数的参数列表,或者获取用来展示一个traceback的所有信息。
from __future__ import unicode_literals, division
import sys
import os
import inspect
import codecs
import io
import re
import warnings
import random
2.2 类初始化
将要编码的文本、merge、分割符号、已经学习到的词汇传入到此类中,然后检查版本信息,将形式参数传递给类内新定义的参数,保存下来,然后供后续方法使用。
def __init__(self, codes, merges=-1, separator='@@', vocab=None, glossaries=None):
    """Load learned BPE merge operations and build the lookup tables.

    codes      -- open file object containing BPE merge operations
    merges     -- apply only the first `merges` operations (-1 means all)
    separator  -- suffix appended to every non-final subword unit
    vocab      -- optional vocabulary set used to filter segmentations
    glossaries -- optional list of patterns that must never be split
    """
    codes.seek(0)
    offset = 1

    # A '#version:' header on the first line selects the end-of-word scheme.
    first = codes.readline()
    if first.startswith('#version:'):
        version_str = re.sub(r'(\.0+)*$', '', first.split()[-1])
        self.version = tuple(int(part) for part in version_str.split("."))
        offset += 1
    else:
        self.version = (0, 1)
        codes.seek(0)

    self.bpe_codes = []
    for n, line in enumerate(codes):
        if merges == -1 or n < merges:
            self.bpe_codes.append(tuple(line.strip('\r\n ').split(' ')))

    # Every merge operation must consist of exactly two subword units.
    for i, pair in enumerate(self.bpe_codes):
        if len(pair) != 2:
            sys.stderr.write('Error: invalid line {0} in BPE codes file: {1}\n'.format(i + offset, ' '.join(pair)))
            sys.stderr.write('The line should exist of exactly two subword units, separated by whitespace\n')
            sys.exit(1)

    # Map pair -> rank; iterating in reverse lets the earliest (lowest-rank)
    # occurrence of a duplicate pair win.
    self.bpe_codes = {code: rank for rank, code in reversed(list(enumerate(self.bpe_codes)))}
    # Map merged string -> (left, right), used to reverse merges for OOV handling.
    self.bpe_codes_reverse = {first_part + second_part: (first_part, second_part)
                              for (first_part, second_part) in self.bpe_codes}

    self.separator = separator
    self.vocab = vocab
    self.glossaries = glossaries if glossaries else []
    self.glossaries_regex = re.compile('^({})$'.format('|'.join(glossaries))) if glossaries else None
    self.cache = {}
2.3 行处理
处理一行文本中的换行符、空格以及前导/尾随空白。即,先把行首和行尾的空白字符剥离出来,对中间的文本做 BPE 分词,最后再把这些空白原样拼接回去。
def process_line(self, line, dropout=0):
    """Segment one line of text, preserving its leading and trailing whitespace."""
    pieces = []

    n_lead = len(line) - len(line.lstrip('\r\n '))
    if n_lead:
        pieces.append(line[:n_lead])

    pieces.append(self.segment(line, dropout))

    n_trail = len(line) - len(line.rstrip('\r\n '))
    # A whitespace-only line was already emitted in full as the leading part.
    if n_trail and n_trail != len(line):
        pieces.append(line[-n_trail:])

    return ''.join(pieces)
2.4 分割
使用BPE编码分割单个输入的句子,dropout 参数默认设置为 0(即不使用 BPE-dropout)。
并且使用BPE编码获得当前句子集合的token集合
def segment(self, sentence, dropout=0):
    """Apply BPE to one sentence; return the space-joined subword string."""
    tokens = sentence.strip('\r\n ').split(' ')
    return ' '.join(self.segment_tokens(tokens, dropout))
def segment_tokens(self, tokens, dropout=0):
    """Apply BPE to a list of tokens; return the flat list of subword units."""
    output = []
    for word in tokens:
        # Empty strings arise from runs of consecutive spaces; skip them.
        if not word:
            continue

        pieces = []
        for segment in self._isolate_glossaries(word):
            pieces.extend(encode(segment,
                                 self.bpe_codes,
                                 self.bpe_codes_reverse,
                                 self.vocab,
                                 self.separator,
                                 self.version,
                                 self.cache,
                                 self.glossaries_regex,
                                 dropout))

        # Every unit except the word-final one carries the continuation separator.
        for piece in pieces[:-1]:
            output.append(piece + self.separator)
        output.append(pieces[-1])
    return output
2.5 定义孤立词汇表
def _isolate_glossaries(self, word):
    """Split `word` so each glossary match becomes its own standalone segment."""
    segments = [word]
    for gloss in self.glossaries:
        # Re-split every current segment on this glossary pattern.
        next_segments = []
        for seg in segments:
            next_segments.extend(isolate_glossary(seg, gloss))
        segments = next_segments
    return segments
2.6 对孤立的词汇表进行合并操作
即对在词汇表中孤立的各个token使用从前到后的合并操作使得词汇表对整个文章做一个更加好的描述。
在进行合并之前,我们先检查version参数,确定一个句子的词头和词尾处理模式。
当句子长度大于1时,获取符号对列表,可以适当地使用dropout参数。
找到第一个合并位置,然后遍历出所有词的合并位置,如果合并在当前位置之前开始,则合并无效。如果有重叠对,就会发生这种情况:(x x x -> xx x)
然后合并对,迭代、循环进行合并。
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, dropout=0):
    """Encode a single word by greedily applying BPE merge operations.

    Returns a tuple of subword units (results are cached per word unless
    BPE-dropout is active; single-character words are returned as-is).
    """
    if not dropout and orig in cache:
        return cache[orig]

    # Glossary words must never be split.
    if glossaries_regex and glossaries_regex.match(orig):
        cache[orig] = (orig,)
        return (orig,)

    if len(orig) == 1:
        return orig

    # The codes-file version decides how the end of a word is marked.
    if version == (0, 1):
        word = list(orig) + ['</w>']
    elif version == (0, 2):
        word = list(orig[:-1]) + [orig[-1] + '</w>']
    else:
        raise NotImplementedError

    while len(word) > 1:
        # Candidate merges as (rank, position, pair); dropout randomly thins them.
        candidates = [(bpe_codes[pair], idx, pair)
                      for idx, pair in enumerate(zip(word, word[1:]))
                      if (not dropout or random.random() > dropout) and pair in bpe_codes]
        if not candidates:
            break

        # Lowest rank wins; merge every non-overlapping occurrence of that pair.
        best_pair = min(candidates)[2]
        positions = [idx for rank, idx, pair in candidates if pair == best_pair]

        merged = ''.join(best_pair)
        new_word = []
        cursor = 0
        for pos in positions:
            if pos < cursor:
                # Overlapping occurrence (x x x -> xx x) already consumed.
                continue
            new_word.extend(word[cursor:pos])
            new_word.append(merged)
            cursor = pos + 2
        new_word.extend(word[cursor:])
        word = new_word

    # Strip the end-of-word marker again.
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word[-1] = word[-1][:-4]

    word = tuple(word)
    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    cache[orig] = word
    return word
2.7 反转BPE合并
递归地将段拆分为更小的单元,直到所有单元都在词汇表中,或者无法进一步拆分
def recursive_split(segment, bpe_codes, vocab, separator, final=False):
    """Recursively split `segment` by reversing BPE merges.

    Yields subword units until every unit is in `vocab` or cannot be split
    further.  `final` marks the word-final segment, which is looked up with
    an implicit '</w>' marker and without the separator suffix.
    """
    try:
        if final:
            left, right = bpe_codes[segment + '</w>']
            right = right[:-4]  # drop the '</w>' marker from the right half
        else:
            left, right = bpe_codes[segment]
    # Fix: the original bare `except:` swallowed every exception (including
    # SystemExit/KeyboardInterrupt); only a failed lookup is expected here.
    except KeyError:
        yield segment
        return

    if left + separator in vocab:
        yield left
    else:
        for item in recursive_split(left, bpe_codes, vocab, separator, False):
            yield item

    # Only the word-final unit appears in the vocabulary without a separator.
    if (final and right in vocab) or (not final and right + separator in vocab):
        yield right
    else:
        for item in recursive_split(right, bpe_codes, vocab, separator, final):
            yield item
检查词汇的子集是否在原先的文本集合中。检查 word 中的每个段是否在词汇表中,并通过反转 BPE 合并操作将 OOV 段分割成更小的单元
def check_vocab_and_split(orig, bpe_codes, vocab, separator):
    """Check each unit of `orig` against `vocab`; split OOV units further.

    Non-final units are looked up with the separator suffix appended; the
    word-final unit is looked up as-is.  Returns the resulting list of units.
    """
    out = []

    for segment in orig[:-1]:
        if segment + separator in vocab:
            out.append(segment)
        else:
            # OOV: reverse BPE merges until all parts are in-vocabulary.
            out.extend(recursive_split(segment, bpe_codes, vocab, separator, False))

    last = orig[-1]
    if last in vocab:
        out.append(last)
    else:
        # OOV word-final unit: split with end-of-word handling.
        out.extend(recursive_split(last, bpe_codes, vocab, separator, True))

    return out
读取get_vocab.py生成的词汇文件,根据频率阈值进行过滤。
def read_vocabulary(vocab_file, threshold):
    """Read a vocabulary file produced by get_vocab.py and filter it.

    Each line has the form "<word> <frequency>".  Words whose frequency is
    below `threshold` are dropped; a threshold of None keeps every word.
    Returns the vocabulary as a set of words.
    """
    vocabulary = set()
    for line in vocab_file:
        word, freq = line.strip('\r\n ').split(' ')
        # Fix: compare to None with `is`, not `==` (PEP 8 identity comparison).
        if threshold is None or int(freq) >= threshold:
            vocabulary.add(word)
    return vocabulary
将出现在单词中的 glossary(词汇表条目)隔离出来,返回子词列表,使每一个 glossary 条目都作为独立的子词出现、不会被后续的 BPE 分词拆开。
def isolate_glossary(word, glossary):
    """Isolate every occurrence of the `glossary` pattern inside `word`.

    Returns a list of subwords in which each glossary occurrence stands on
    its own.  A word that matches the glossary exactly, or contains no
    occurrence at all, is returned unchanged as a one-element list.
    """
    if not re.search(glossary, word) or re.match('^' + glossary + '$', word):
        return [word]

    parts = re.split(r'({})'.format(glossary), word)
    *body, ending = parts
    body = [p for p in body if p]  # drop empty strings produced by the capturing split
    return body + [ending.strip('\r\n ')] if ending != '' else body