一, BPE编码
(Byte Pair Encoding,简称 BPE)方法,BPE 是一种能够解决未登录词问题,并减小词典大小的方法。它综合利用了单词层面编码和字符层面编码的优势,
举例来说,我们要对下面的字符串编码,
aaabdaaabac
字节对 aa 出现的次数最多,所以我们将它替换成一个没在字符串中被用过的字符 Z ,
ZabdZabac
Z=aa
然后我们重复这个过程,用 Y 替换 ab ,
ZYdZYac
Y=ab
Z=aa
继续,用 X 替换 ZY ,
XdXac
X=ZY
Y=ab
Z=aa
这个过程重复进行,直到没有字节对出现超过一次。当需要解码时,就将上述替换过程反向进行。
下面是一段 BPE 算法原文中对 BPE 算法的实现:
import re
import collections
def get_stats(vocab):
pairs = collections.defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols)-1):
pairs[symbols[i], symbols[i+1]] += freq # 计算字节对出现频率
return pairs
def merge_vocab(pair, v_in):
v_out = {}
bigram = re.escape(' '.join(pair)) # 将字节对中可解释为正则运算符的字符转义
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)') # 将要合并的字节对前后只能为空白字符
for word in v_in:
w_out = p.sub(''.join(pair), word) # 合并符合条件的字节对
v_out[w_out] = v_in[word]
return v_out
vocab = {'l o w </w>': 5, 'l o w e r </w>': 2,
'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
num_merges = 10
for i in range(num_merges):
pairs = get_stats(vocab)
best = max(pairs, key=pairs.get) # 选择频率最大的字节对
vocab = merge_vocab(best, vocab)
print(best)