Byte Pair Encoding

Introduction

In natural language processing models, the inputs are usually sequences of sentences, such as "I went to New York last week.". Each sequence consists of tokens. In older language models, tokens were usually whitespace-separated words and punctuation marks, such as ["i", "went", "to", "new", "york", "last", "week", "."]. I remember the word2vec model used this kind of tokenization. However, it has some drawbacks. For example, if the model has learned the relationship between "old", "older", and "oldest", that tells it nothing about the relationship between "smart", "smarter", and "smartest". If we instead use subtokens such as "er" and "est", then learning the relationship between "old", "older", and "oldest" also carries information about the relationship between "smart", "smarter", and "smartest". Now that subtokens sound like a better tokenization unit, how do we engineer them? Preparing subtokens by hand is heuristic and labor-intensive, and thus undesirable. But information theory comes and saves us again.
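A minimal sketch of the word-level tokenization described above (a hypothetical snippet, not from any particular library):

```python
import re

sentence = "I went to New York last week."
# lowercase, then split into words and standalone punctuation marks
tokens = re.findall(r"\w+|[^\w\s]", sentence.lower())
print(tokens)  # ['i', 'went', 'to', 'new', 'york', 'last', 'week', '.']
```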

BPE Algorithm

Token Learning from Dataset

The authors of "Neural Machine Translation of Rare Words with Subword Units" open-sourced the code for byte-pair encoding the words in a corpus. First we count the frequency of every word in the corpus. For each word, we append a special "stop token" </w> to its end and then split it into characters. Initially, a word's tokens are all of its characters plus the trailing stop token; for example, the word "low" becomes ["l", "o", "w", "</w>"]. After processing all the words in the dataset this way, we obtain a vocabulary of "tokenized" words and their occurrence counts (representing their frequencies), for example:

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

We then start iterating. Each iteration counts the frequency of every consecutive byte pair, merges the most frequent pair into a single new token, and replaces all occurrences of that pair in the vocabulary with the merged token.


For example, in the first merge iteration on the vocabulary above, the byte pair "e" and "s" occurs 6 + 3 = 9 times, the most of any pair, so we merge "e" and "s" into "es":

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}

After this merge, the token "s" no longer appears on its own in the vocabulary (note that "e" still does, e.g. in 'l o w e r </w>'). In the second iteration, the pair "es" and "t" is the most frequent, again occurring 6 + 3 = 9 times, so we merge it into a new token "est"; after that merge, the tokens "es" and "t" no longer exist in the vocabulary.
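The first merge round above can be reproduced with a few lines of Python (the same `get_stats`/`merge_vocab` helpers appear in the full listing below):

```python
import re, collections

vocab = {'l o w </w>': 5, 'l o w e r </w>': 2,
         'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

def get_stats(vocab):
    # frequency of every adjacent symbol pair, weighted by word frequency
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    # merge the pair wherever it appears as two separate symbols
    p = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    return {p.sub(''.join(pair), w): f for w, f in v_in.items()}

pairs = get_stats(vocab)
best = max(pairs, key=pairs.get)
print(best, pairs[best])   # ('e', 's') 9
vocab = merge_vocab(best, vocab)
print(vocab)
# {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
```

Note that ('s', 't') and ('t', '</w>') also occur 9 times; `max` breaks the tie by picking the pair encountered first.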

Finally, we control how many tokens we end up with by limiting either the number of merge iterations or the maximum number of tokens.


The stop token </w> is also very important. Without </w>, a token such as "st" could come from two very different words (such as "star" and "widest"), and it would be hard to tell which word it belongs to; with </w>, when an iteration produces "st</w>" the model knows this token comes from "wide st</w>" rather than "st ar</w>".
After each merge, the number of distinct tokens in the vocabulary will:
  • decrease by 1 (if both merged tokens disappear from the vocabulary)
  • increase by 1 (if both merged tokens still appear elsewhere on their own)
  • stay the same (if exactly one of the merged tokens still appears elsewhere)

In practice, the number of tokens tends to increase more often than it decreases.
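The three cases can be checked directly on tiny vocabularies (a minimal sketch; the merge helper mirrors the merge_vocab function in the listing below):

```python
import re

def merge(pair, vocab):
    # replace the pair (as two separate symbols) with the merged symbol
    p = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    return {p.sub(''.join(pair), w): f for w, f in vocab.items()}

def n_tokens(vocab):
    # number of distinct tokens across the vocabulary
    return len({t for w in vocab for t in w.split()})

# both 'a' and 'b' disappear after merging: count decreases by 1
v = {'a b </w>': 1}
print(n_tokens(merge(('a', 'b'), v)) - n_tokens(v))   # -1

# both 'a' and 'b' still appear elsewhere: count increases by 1
v = {'a b </w>': 1, 'a </w>': 1, 'b </w>': 1}
print(n_tokens(merge(('a', 'b'), v)) - n_tokens(v))   # 1

# only 'a' survives elsewhere: count stays the same
v = {'a b </w>': 1, 'a </w>': 1}
print(n_tokens(merge(('a', 'b'), v)) - n_tokens(v))   # 0
```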

Token Learning Example

The author adapted the code from the paper "Neural Machine Translation of Rare Words with Subword Units" as an example of applying BPE to a real dataset.

import re, collections

def get_vocab(filename):
    # Build the initial vocabulary: each word becomes space-separated
    # characters plus a trailing </w>, mapped to its corpus frequency
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_stats(vocab):
    # Count the frequency of every adjacent symbol pair, weighted by word frequency
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    # Replace every occurrence of the pair (as two separate symbols) with the merged symbol
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens(vocab):
    # Aggregate token frequencies over the whole vocabulary
    tokens = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
    return tokens

# vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

# Get free book from Gutenberg
# wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt
vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')

num_merges = 1000
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens = get_tokens(vocab)
    print('Tokens: {}'.format(tokens))
    print('Number of tokens: {}'.format(len(tokens)))
    print('==========')

When the program runs correctly, the expected output is as follows:

==========
Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'\ufeff': 1, 'T': 1610, 'h': 26094, 'e': 59152, '</w>': 101830, 'P': 780, 'r': 29540, 'o': 34983, 'j': 857, 'c': 13891, 't': 44258, 'G': 300, 'u': 13731, 'n': 32499, 'b': 7428, 'g': 8744, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28320, 'V': 104, 'i': 31435, 'a': 36692, 'w': 8133, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô': 1, 'Æ': 3, 'æ': 2, '%': 1, '@': 2, '$': 2})
Number of tokens: 98
==========
Iter: 0
Best pair: ('e', '</w>')
Tokens: defaultdict(<class 'int'>, {'\ufeff': 1, 'T': 1610, 'h': 26094, 'e</w>': 17749, 'P': 780, 'r': 29540, 'o': 34983, 'j': 857, 'e': 41403, 'c': 13891, 't': 44258, '</w>': 84081, 'G': 300, 'u': 13731, 'n': 32499, 'b': 7428, 'g': 8744, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28320, 'V': 104, 'i': 31435, 'a': 36692, 'w': 8133, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô': 1, 'Æ': 3, 'æ': 2, '%': 1, '@': 2, '$': 2})
Number of tokens: 99
==========
Iter: 1
Best pair: ('t', 'h')
Tokens: defaultdict(<class 'int'>, {'\ufeff': 1, 'T': 1610, 'h': 12065, 'e</w>': 17749, 'P': 780, 'r': 29540, 'o': 34983, 'j': 857, 'e': 41403, 'c': 13891, 't': 30229, '</w>': 84081, 'G': 300, 'u': 13731, 'n': 32499, 'b': 7428, 'g': 8744, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'th': 14029, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28320, 'V': 104, 'i': 31435, 'a': 36692, 'w': 8133, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô': 1, 'Æ': 3, 'æ': 2, '%': 1, '@': 2, '$': 2})
Number of tokens: 100
==========
Iter: 2
Best pair: ('t', '</w>')
Tokens: defaultdict(<class 'int'>, {'\ufeff': 1, 'T': 1610, 'h': 12065, 'e</w>': 17749, 'P': 780, 'r': 29540, 'o': 34983, 'j': 857, 'e': 41403, 'c': 13891, 't</w>': 9271, 'G': 300, 'u': 13731, 't': 20958, 'n': 32499, 'b': 7428, 'g': 8744, '</w>': 74810, 'E': 901, 'B': 1163, 'k': 2726, 'f': 10469, 'A': 1381, 'l': 20632, 'd': 17576, 'th': 14029, 'M': 1206, ',': 8068, 'y': 8812, 'J': 80, 's': 28320, 'V': 104, 'i': 31435, 'a': 36692, 'w': 8133, 'm': 9812, 'v': 4880, '.': 4055, 'Y': 250, 'p': 8040, '-': 1128, 'L': 429, ':': 209, 'R': 369, 'D': 327, '6': 77, '2': 158, '0': 401, '5': 131, '[': 32, '#': 1, '1': 295, '4': 104, '7': 65, ']': 32, '*': 44, 'S': 860, 'O': 510, 'F': 422, 'H': 689, 'I': 1432, 'C': 863, 'U': 170, 'N': 796, 'K': 42, '/': 52, '"': 4086, '!': 1214, 'W': 579, '3': 105, "'": 1243, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 75, '9': 38, '_': 1426, 'à': 3, 'x': 937, 'z': 365, '°': 41, 'q': 575, ';': 561, '(': 56, ')': 56, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô': 1, 'Æ': 3, 'æ': 2, '%': 1, '@': 2, '$': 2})
Number of tokens: 101
==========

Encoding and Decoding

Decoding is very simple: we just concatenate all the tokens, recovering the original words. For example, if the encoded sequence is

["the</w>", "high", "est</w>", "moun", "tain</w>"]

then we immediately know the decoded sequence is

the</w> highest</w> mountain</w>
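A decoding sketch (assuming we also want to turn the </w> markers back into spaces to recover plain text):

```python
def decode(tokens):
    # concatenate all tokens, then turn each end-of-word marker into a space
    return ''.join(tokens).replace('</w>', ' ').strip()

print(decode(['the</w>', 'high', 'est</w>', 'moun', 'tain</w>']))
# the highest mountain
```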

As for encoding, suppose we have a word sequence:

["the</w>", "highest</w>", "mountain</w>"]

We maintain a list of tokens sorted from longest to shortest. For each word, we walk down this list and check whether each token is a substring of the word; if it is, that token becomes one of the word's tokens.
For example, suppose the token list is:

["errrr</w>", "tain</w>", "moun", "est</w>", "high", "the</w>", "a</w>"]

We iterate from the longest token, "errrr</w>", down to the shortest, "a</w>", trying to replace substrings of each word with tokens from the list.
If some substrings remain unmatched after the whole list has been tried, we replace those remaining subwords with an unknown token.
Given the token list above and the word sequence above, we tokenize

  • the word "the</w>" as ["the</w>"]
  • the word "highest</w>" as ["high", "est</w>"]
  • the word "mountain</w>" as ["moun", "tain</w>"]
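The longest-token-first procedure can be sketched with a simplified greedy matcher (an illustrative stand-in; it only handles the first occurrence of each token per call, unlike the fuller tokenize_word in the listing below):

```python
def greedy_tokenize(word, sorted_tokens, unknown='</u>'):
    # sorted_tokens must be ordered longest-first; unmatched parts become `unknown`
    if word == '':
        return []
    for tok in sorted_tokens:
        i = word.find(tok)
        if i == -1:
            continue
        # recursively tokenize the parts before and after the match
        return (greedy_tokenize(word[:i], sorted_tokens, unknown)
                + [tok]
                + greedy_tokenize(word[i + len(tok):], sorted_tokens, unknown))
    return [unknown]

token_list = ['errrr</w>', 'tain</w>', 'moun', 'est</w>', 'high', 'the</w>', 'a</w>']
print(greedy_tokenize('highest</w>', token_list))   # ['high', 'est</w>']
print(greedy_tokenize('mountain</w>', token_list))  # ['moun', 'tain</w>']
```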

In practice, encoding is computationally expensive. Usually we can pre-tokenize all words and save their tokenizations in a dictionary. If the encoder meets a word that is not in the dictionary, only then does it run the encoding procedure above. Below is an example that learns encodings of both known and unknown words from a dataset, with the number of merges set relatively large:

import re, collections

def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        for line in fhand:
            words = line.strip().split()
            for word in words:
                vocab[' '.join(list(word)) + ' </w>'] += 1

    return vocab

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def get_tokens_from_vocab(vocab):
    tokens_frequencies = collections.defaultdict(int)
    vocab_tokenization = {}
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens_frequencies[token] += freq
        vocab_tokenization[''.join(word_tokens)] = word_tokens
    return tokens_frequencies, vocab_tokenization

def measure_token_length(token):
    if token[-4:] == '</w>':
        return len(token[:-4]) + 1
    else:
        return len(token)

def tokenize_word(string, sorted_tokens, unknown_token='</u>'):
    
    if string == '':
        return []
    if sorted_tokens == []:
        return [unknown_token]

    string_tokens = []
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        token_reg = re.escape(token)  # escape regex metacharacters such as '.'

        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
        if len(matched_positions) == 0:
            continue
        substring_end_positions = [matched_position[0] for matched_position in matched_positions]

        substring_start_position = 0
        for substring_end_position in substring_end_positions:
            substring = string[substring_start_position:substring_end_position]
            string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
            string_tokens += [token]
            substring_start_position = substring_end_position + len(token)
        remaining_substring = string[substring_start_position:]
        string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
        break
    return string_tokens

# vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

vocab = get_vocab('pg16457.txt')

print('==========')
print('Tokens Before BPE')
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
print('==========')

num_merges = 10000
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
    print('All tokens: {}'.format(tokens_frequencies.keys()))
    print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
    print('==========')

# Let's check how tokenization works for a known word and an unknown word
word_given_known = 'mountains</w>'
word_given_unknown = 'Ilikeeatingapples!</w>'

sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item: (measure_token_length(item[0]), item[1]), reverse=True)
sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]

print(sorted_tokens)

word_given = word_given_known 

print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
    print('Tokenization of the known word:')
    print(vocab_tokenization[word_given])
    print('Tokenization treating the known word as unknown:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
else:
    print('Tokenization of the unknown word:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))

word_given = word_given_unknown 

print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
    print('Tokenization of the known word:')
    print(vocab_tokenization[word_given])
    print('Tokenization treating the known word as unknown:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
else:
    print('Tokenization of the unknown word:')
    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))

Using the comprehensive encoding method described above, the known word "mountains</w>" was tokenized as ["mountains</w>"], which also matches the learned tokenization of known words saved in the dictionary. The unknown, invented word "Ilikeeatingapples!</w>" was tokenized as ['I', 'like', 'ea', 'ting', 'app', 'l', 'es!</w>'].
