Introduction
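BPE (byte pair encoding) builds a vocabulary by repeatedly replacing the most frequent pair of adjacent symbols with a new symbol. The variant described here is byte-level: it runs on the UTF-8 encoding of a string and starts from its individual bytes, so in principle any UTF-8 string can be handled by the same algorithm.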
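A quick Python check (standard library only) shows the starting point of byte-level BPE: a string becomes a list of integers in the range 0..255. ASCII letters take one byte each, while each of the Chinese characters below takes three bytes in UTF-8. Either way, 256 base tokens are enough to represent any input.

```python
# Byte-level view of a string: BPE starts from these integers (0..255).
ascii_ids = list("aaab".encode("utf-8"))
cjk_ids = list("字节".encode("utf-8"))

print(ascii_ids)  # [97, 97, 97, 98]
print(cjk_ids)    # six integers: each of the two characters is 3 bytes in UTF-8

# every value fits in one byte, so a 256-entry base vocabulary covers any input
assert all(0 <= b <= 255 for b in ascii_ids + cjk_ids)
# the bytes decode back to the original string
assert bytes(cjk_ids).decode("utf-8") == "字节"
```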
How the algorithm works
In my view the algorithm is quite simple, so let me walk through it.
Suppose we want to encode:
aaabdaaabac
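The pair aa occurs most often (we only look at pairs of two adjacent characters; aa is counted 4 times, including the overlapping pairs inside each aaa). We replace it with Z, a character that does not appear in the data: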
ZabdZabac
Z=aa
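The pair counts behind this first step can be checked with `collections.Counter`. This is a character-level sketch of the walkthrough, not the byte-level code later in the article. Although aa is counted 4 times, the replacement scans left to right and never reuses a character, so each aaa becomes Za.

```python
from collections import Counter

s = "aaabdaaabac"
# count every pair of adjacent characters, overlapping pairs included
pairs = Counter(a + b for a, b in zip(s, s[1:]))
print(pairs.most_common(2))  # [('aa', 4), ('ab', 2)]

# str.replace substitutes non-overlapping matches from left to right
step1 = s.replace("aa", "Z")
print(step1)  # ZabdZabac
```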
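Now ab is the most frequent pair. It occurs twice, and so does Za; we break the tie in favor of ab and likewise replace it with Y: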
ZYdZYac
Y=ab
Z=aa
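This second step involves a tie, which a short Python check makes visible (again a character-level sketch). Za and ab both occur twice in ZabdZabac. The walkthrough picks ab. Python's built-in `max`, which the training code later in this article uses, would instead return Za, because on a tie `max` returns the first maximal item it encounters.

```python
from collections import Counter

s = "ZabdZabac"
pairs = Counter(a + b for a, b in zip(s, s[1:]))

# Za and ab both occur twice, so the "most frequent pair" is a tie
top = max(pairs.values())
tied = sorted(p for p, c in pairs.items() if c == top)
print(tied)  # ['Za', 'ab']

# max() returns the first maximal key encountered, i.e. the pair counted first
chosen = max(pairs, key=pairs.get)
print(chosen)  # Za
```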
Likewise, ZY now occurs most often (twice), so we replace it with X:
XdXac
X=ZY
Y=ab
Z=aa
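Finally, every pair of adjacent characters occurs only once, so another replacement would not shorten the string, and the algorithm stops. That is all there is to it.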
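The whole compression loop can be sketched at the character level in Python. This is a sketch under stated assumptions, not the byte-level code further down. The function name `bpe_compress` is ours. Each new symbol is assumed to be a single character absent from the input, and ties go to the pair seen first.

Because of that tie rule, step 2 merges Za rather than ab, so the table differs from the one above (Y=Za, X=Yb). X still stands for aaab, however, and the result is the same XdXac.

```python
from collections import Counter

def bpe_compress(s, symbols):
    """Character-level BPE sketch: repeatedly replace the most frequent
    adjacent pair with the next unused symbol. Assumes every symbol in
    `symbols` is a single character that does not occur in `s`."""
    table = []  # (new_symbol, pair) in the order the merges were made
    for sym in symbols:
        pairs = Counter(a + b for a, b in zip(s, s[1:]))
        if not pairs:
            break  # fewer than two characters left
        pair = max(pairs, key=pairs.get)  # ties: first pair encountered
        if pairs[pair] < 2:
            break  # every pair occurs once; merging no longer helps
        s = s.replace(pair, sym)  # non-overlapping, left to right
        table.append((sym, pair))
    return s, table

compressed, table = bpe_compress("aaabdaaabac", "ZYX")
print(compressed)  # XdXac
print(table)       # [('Z', 'aa'), ('Y', 'Za'), ('X', 'Yb')]
```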
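To decode, undo the replacements in reverse order: expand X into ZY, then Y into ab, and finally Z into aa, which recovers aaabdaaabac.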
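Decoding can be sketched the same way; `bpe_decompress` is a hypothetical helper name. The order matters: X expands into symbols that are themselves replacements, so the newest replacement has to be undone first.

```python
def bpe_decompress(s, table):
    """Undo the replacements in reverse order of creation.
    `table` lists (symbol, pair) in the order the merges were made."""
    for sym, pair in reversed(table):
        s = s.replace(sym, pair)
    return s

# replacement table from the walkthrough above: Z=aa, then Y=ab, then X=ZY
table = [("Z", "aa"), ("Y", "ab"), ("X", "ZY")]
print(bpe_decompress("XdXac", table))  # aaabdaaabac
```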
Code
The code below follows karpathy/minbpe (see References).
Two helper functions
# Count how often each pair of adjacent elements occurs
def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally updates an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):  # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
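For comparison, the counting in `get_stats` can also be written with `collections.Counter` from the standard library. This is an alternative we are suggesting, not part of the original code, and `get_stats_counter` is a hypothetical name. To accumulate counts across several id lists, as the `counts` argument allows, one could call `Counter.update` with the extra pairs.

```python
from collections import Counter

def get_stats_counter(ids):
    """Same counts as get_stats (without its optional `counts`
    argument), built with collections.Counter."""
    return Counter(zip(ids, ids[1:]))

stats = get_stats_counter([1, 2, 3, 1, 2])
print(dict(stats))  # {(1, 2): 2, (2, 3): 1, (3, 1): 1}

# accumulating over a second list, analogous to passing `counts` to get_stats
more = [2, 3]
stats.update(zip(more, more[1:]))
print(stats[(2, 3)])  # 2
```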
# Replace every occurrence of pair in ids with idx
def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        # (matches are taken left to right without overlap: [a, a, a] -> [idx, a])
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
Training
# BPE training
text = "aaabdaaabac"  # training corpus
vocab_size = 256 + 3  # vocabulary size: 256 byte tokens plus 3 merges
text_bytes = text.encode("utf-8")  # raw bytes
ids = list(text_bytes)  # list of integers in range 0..255
# iteratively merge the most common pairs to create new tokens
merges = {}  # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
num_merges = vocab_size - 256
for i in range(num_merges):
    if len(ids) <= 1:
        break  # no pairs left to merge
    # count up the number of times every consecutive pair appears
    stats = get_stats(ids)
    # find the pair with the highest count (on a tie, max returns the first pair counted)
    pair = max(stats, key=stats.get)
    # mint a new token: assign it the next available id
    idx = 256 + i
    # replace all occurrences of pair in ids with idx
    ids = merge(ids, pair, idx)
    # save the merge
    merges[pair] = idx  # used in encode
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]  # used in decode
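Tracing the loop above on aaabdaaabac by hand, it learns three merges: (97, 97) → 256, (256, 97) → 257 and (257, 98) → 258, leaving ids = [258, 100, 258, 97, 99]. The second merge is (256, 97), i.e. "aa" + "a", rather than the walkthrough's ab. That is because `max` breaks the tie between the two equally frequent pairs in favor of the one counted first.

The sketch below takes those traced merges as given and rebuilds `vocab` from `merges` alone, in increasing id order. This works because each new token only refers to ids created before it.

```python
# merges learned by the training loop on "aaabdaaabac" (traced by hand, taken as given)
merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}

# rebuild the byte vocabulary: 256 single bytes, then merged tokens in id order
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in sorted(merges.items(), key=lambda kv: kv[1]):
    vocab[idx] = vocab[p0] + vocab[p1]

for idx in range(256, 259):
    print(idx, vocab[idx])
# 256 b'aa'
# 257 b'aaa'
# 258 b'aaab'
```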
Encoding and decoding
# Encoding
text_bytes = text.encode("utf-8")  # raw bytes
ids = list(text_bytes)  # list of integers in range 0..255
while len(ids) >= 2:
    # count the pairs currently present
    stats = get_stats(ids)
    # find the pair with the lowest merge index (the merge learned earliest) and merge it
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    # subtle: if there are no more merges available, the key will
    # result in an inf for every single pair, and the min will be
    # just the first pair in the list, arbitrarily
    # we can detect this terminating case by a membership check
    if pair not in merges:
        break  # nothing else can be merged anymore
    # otherwise let's merge the best pair (lowest merge index)
    idx = merges[pair]
    ids = merge(ids, pair, idx)
print(ids)
# Decoding
# given ids (list of integers), return a Python string
text_bytes = b"".join(vocab[idx] for idx in ids)
text = text_bytes.decode("utf-8", errors="replace")
print(text)
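The `errors="replace"` argument matters because byte-level tokens need not line up with character boundaries. A token may hold only part of a multi-byte UTF-8 character, so a sequence of ids can be invalid UTF-8. With `errors="replace"`, Python substitutes U+FFFD instead of raising `UnicodeDecodeError`. The vocabulary and the `decode` wrapper below are hypothetical, made up for illustration.

```python
# Hypothetical vocabulary in which one token holds only part of a UTF-8 character.
# "你" is encoded as three bytes; assume token 256 covers the first two of them.
full = "你".encode("utf-8")
vocab = {i: bytes([i]) for i in range(256)}
vocab[256] = full[:2]  # assumption for illustration: a token ending mid-character

def decode(ids):
    """Same two steps as the decoding code above, wrapped in a function."""
    text_bytes = b"".join(vocab[idx] for idx in ids)
    return text_bytes.decode("utf-8", errors="replace")

print(decode([256, full[2]]))  # 你 (the last byte completes the character)
print(decode([256]))           # contains U+FFFD, the replacement character
```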
References
Byte pair encoding
karpathy/minbpe
tokenizers
一分钟搞懂的算法之BPE算法 ("Understand an algorithm in one minute: BPE")