Preface
Tokenization is an important but rather dry part of NLP and of today's large language models. Understanding it helps with later optimization and debugging work, so I will spread it over two posts.
Early tokenization methods worked at word or character granularity. Character-level tokenization is simple and needs only a small vocabulary, but the splits are so fine that semantic information gets lost, so it is rarely used. Word-level tokenization requires building a dictionary and matching the input string against it; the word search/matching algorithm and vocabulary construction are the key pieces. A Zhihu blogger explains this in detail, so I will just link it: NLP中的Tokenization - 知乎 (zhihu.com)
The tokenization algorithms commonly used today are subword-based, such as GPT-2's BPE and BERT's WordPiece.
BPE on UTF-8 bytes
BPE stands for Byte Pair Encoding, originally a data compression method. The paper "Neural Machine Translation of Rare Words with Subword Units" first brought it into NLP. In GPT-2, the text is first encoded as UTF-8, producing a list of integer indices (the tokens), and BPE is then applied on top of them.
Here I follow Karpathy's code for a simple implementation; you can also go straight to his lectures -> Let's build the GPT Tokenizer (youtube.com) and p10 of Let's reproduce GPT-2 (124M)_哔哩哔哩_bilibili
Training procedure
Given a piece of training text, iterate over it to collect adjacent index pairs and their counts, take the most frequent pair as a new index, add it to merge_list, and then merge the corresponding pairs of tokens in the text to form a new token list for the next pass. Repeating this loop eventually yields a large merge_list and the final merged tokens. Before processing the text we first convert it to UTF-8 byte indices and then merge. The implementation:
# Count each adjacent token pair and how often it occurs
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # replace every occurrence of `pair` in `ids` with the new token `idx`
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
# ---
tokens = "training texts".encode('utf-8')
# this is a toy example; real training text should be much longer and more varied
vocab_size = 276  # the desired final vocabulary size
num_merges = vocab_size - 256  # the first 256 ids are reserved for the raw byte values
ids = list(tokens)  # copy so we don't destroy the original list
print(len(ids))
merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)  # pick the most frequent pair this round
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
Notice that we specified vocab_size, the size of the final vocabulary, which determines how many rounds of merging are needed. GPT-2's vocabulary size is 50257, meaning it performed 50000 merges; the remaining 257 are the 256 base byte tokens plus one special token, <|endoftext|>.
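With the merges from training we can also decode: build a vocab mapping each token id back to its byte sequence (ids 0 to 255 are the raw bytes, and each merged id is the concatenation of its two children), then join the bytes and UTF-8 decode. Below is a minimal sketch in the spirit of Karpathy's lecture; the names vocab and decode are mine, not part of the snippet above:

# Build an id -> bytes table from the learned merges (assumes `merges` from the loop above)
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():  # insertion order guarantees both child ids already exist
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    # concatenate the byte sequences of all tokens, then decode the result as UTF-8
    return b"".join(vocab[idx] for idx in ids).decode("utf-8", errors="replace")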
Encoding and decoding at inference time (given a merge list)
In practice, the merge list is obtained through training (roughly the procedure above) and is then used to merge new token sequences. Applying the merge list at inference time differs slightly from the merging done during training: in each iteration we collect the token pairs in the current sequence and look for a pair that exists in the merge list. Note that unlike training we do not need the most frequent pair, only one that appears in the merge list (the code below picks the one with the smallest merge index, i.e. the earliest-learned merge); we then merge and replace it in the token sequence, and iterate until no pair in the tokens appears in the merge list anymore, at which point merging is finished.
The implementation follows; as before, we convert the text to UTF-8 before processing:
def encode(text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)  # see the implementation above
        # pick the pair whose key value is smallest: pairs in `merges` return their merge index,
        # everything else returns inf, so any non-inf result is a valid (and earliest-learned) pair
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        # double-check the pair is valid, to cover the case where everything returned inf
        if pair not in merges:
            break  # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens
print(encode("hello world"))
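As a quick sanity check, feeding the result back through the decode sketch from the training section (using the same merges) should recover the original string:

# round-trip check, assuming the decode() sketch defined earlier in this post
print(decode(encode("hello world")))  # expected: hello world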
GPT-2's tokenizer (inference)
GPT-2's source code is public, including the tokenizer: gpt-2/src/encoder.py at master · openai/gpt-2 · GitHub
Here I copy parts of it into a new script for study and reference.
Regular expression
GPT-2's encoding is also BPE, but with two main differences: the use of a regular expression and the introduction of a byte_encoder. Previously we encoded the whole input text as UTF-8 and ran BPE directly on it; GPT-2 first splits the text into chunks with a regex and then byte-encodes and BPE-encodes each chunk separately, which gives cleaner semantic boundaries.
import regex as re
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
text="hello world"
tokens = re.findall(gpt2pat, text)
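A slightly richer input makes the pattern's behavior concrete: letters, digits, and punctuation end up in separate chunks, and a single leading space stays attached to the chunk that follows it. The output below is just an illustration printed from the snippet above:

print(re.findall(gpt2pat, "Hello world123 how's it!!"))
# ['Hello', ' world', '123', ' how', "'s", ' it', '!!']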
Byte-Encoder
The byte_encoder is a lookup table over the 256 possible byte values: the text is first encoded into UTF-8 byte tokens, each byte is mapped through this table to a Unicode character, and only then is the resulting string fed to BPE. As the docstring below explains, the point is to give every byte a visible, printable character, so the BPE code never has to handle whitespace or control characters directly.
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

byte_encoder = bytes_to_unicode()
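A quick look at the table makes the idea concrete: printable ASCII bytes map to themselves, while bytes that would be whitespace or control characters are shifted to code points above 255 so that every byte becomes a visible character. This is why GPT-2 vocabulary entries show 'Ġ' where a leading space would be:

print(byte_encoder[ord('a')])   # 'a'  -- printable bytes map to themselves
print(byte_encoder[ord(' ')])   # 'Ġ' (U+0120) -- space becomes a visible character
print(byte_encoder[ord('\n')])  # 'Ċ' (U+010A) -- so does newline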
BPE in GPT2
The BPE flow is roughly the same as above. Here encoder.json acts as the vocab, mapping each merged string to its index; vocab.bpe is the list of merge pairs, the counterpart of our earlier merge list, and from it we build a {pair: rank} dictionary used when scanning. Both files can be downloaded with wget or in a browser.
# !wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/vocab.bpe
# !wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/encoder.json
In addition, OpenAI added a few neat tricks that speed up processing. For example:
1. cache: since GPT-2 first splits the text into many chunks and runs BPE on each one independently, a cache recording already-processed chunks and their results avoids repeating the same BPE computation.
2. word.index(first, i): jump directly to the next position where the first element of the current pair occurs, instead of scanning element by element, which speeds up the scan considerably.
import os, json
# !wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/vocab.bpe
# !wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/encoder.json
with open('encoder.json', 'r') as f:
    encoder = json.load(f)  # <--- ~equivalent to our "vocab"
with open('vocab.bpe', 'r', encoding="utf-8") as f:
    bpe_data = f.read()
"""Byte pair encoding utilities"""
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]  # <--- ~equivalent to our "merges"
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
cache = {}

def bpe(token):
    if token in cache:
        return cache[token]
    word = tuple(token)
    pairs = get_pairs(word)  # get_pairs (defined in the complete script below) returns the set of adjacent symbol pairs
    if not pairs:
        return token
    while True:
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            # the try block speeds up scanning: jump straight to the next position where `first` occurs
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break
            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = ' '.join(word)
    cache[token] = word
    return word
Flow summary
The encoding flow can be summarized as:
regex split into chunks -> UTF-8 encode each chunk -> byte_encoder maps each byte to a character, forming a string -> BPE iteratively merges characters -> encoder.json (the vocab) maps each merged string to its index, giving the tokens
The decoding flow is simpler:
tokens -> invert the vocab's key/value pairs to get the decoder and look up each token's string -> invert byte_encoder to get byte_decoder and map each character back to its UTF-8 byte -> UTF-8 decode -> text
The complete code is below:
import os, json
import regex as re
from functools import lru_cache
# !wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/vocab.bpe
# !wget https://openaipublic.blob.core.windows.net/gpt-2/models/1558M/encoder.json
with open('encoder.json', 'r') as f:
    encoder = json.load(f)  # <--- ~equivalent to our "vocab"
with open('vocab.bpe', 'r', encoding="utf-8") as f:
    bpe_data = f.read()
"""Byte pair encoding utilities"""
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]  # <--- ~equivalent to our "merges"

def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
cache = {}

def bpe(token):
    if token in cache:
        return cache[token]
    word = tuple(token)
    pairs = get_pairs(word)
    if not pairs:
        return token
    while True:
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            # the try block speeds up scanning: jump straight to the next position where `first` occurs
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break
            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = ' '.join(word)
    cache[token] = word
    return word
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

byte_encoder = bytes_to_unicode()
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def encode(text):
    bpe_tokens = []
    print(f"split tokens: {re.findall(gpt2pat, text)}\n")
    for token in re.findall(gpt2pat, text):
        print(f"token encode utf-8: {list(token.encode('utf-8'))}")
        token = ''.join(byte_encoder[b] for b in token.encode('utf-8'))
        print(f"byte token: {list(token)}")
        bpe_token = bpe(token)
        print(f"bpe token: {bpe_token}\n")
        bpe_tokens.extend(encoder[piece] for piece in bpe_token.split(' '))
    return bpe_tokens

def decode(tokens):
    decoder = {v: k for k, v in encoder.items()}
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    text = ''.join(decoder[token] for token in tokens)
    text = bytearray([byte_decoder[c] for c in text]).decode('utf-8')
    return text
text = "elhllo world"
encoded_tokens = encode(text)
print(f"encode: {encoded_tokens}")
print(f"decode: {decode(encoded_tokens)}")