BPE: Principles and a Simple Code Demonstration

Introduction

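BPE (byte pair encoding) builds tokens by repeatedly merging the most frequent pair of adjacent symbols. The variant shown here is byte-level: it runs on the UTF-8 encoding of a string and splits it into individual bytes, so the initial symbols are the 256 possible byte values. Because any string can be encoded as UTF-8, the same BPE procedure can be applied uniformly to any text.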
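
To make "byte-level" concrete, here is a minimal sketch of how a string becomes the initial token sequence. The helper name `to_byte_ids` and the sample strings are only illustrative and are not part of the code later in this post.

```python
def to_byte_ids(text: str) -> list[int]:
    """Return the UTF-8 bytes of text as a list of integers in 0..255."""
    return list(text.encode("utf-8"))

# ASCII characters take one byte each in UTF-8.
ascii_ids = to_byte_ids("abc")
# These two CJK characters take three bytes each in UTF-8.
cjk_ids = to_byte_ids("你好")

print(ascii_ids)  # [97, 98, 99]
print(cjk_ids)    # [228, 189, 160, 229, 165, 189]
```

Every value stays below 256, which is why the base vocabulary in the training code has exactly 256 entries.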

Algorithm

I find this algorithm quite simple. Here is a walkthrough.

Suppose we want to encode:

aaabdaaabac

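The pair aa occurs more often than any other (we only count pairs of adjacent characters), so we replace it with Z, a character that does not appear in the data: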
ZabdZabac
Z=aa

Now ab is among the most frequent pairs (it occurs twice, tied with Za), so in the same way we replace it with Y:
ZYdZYac
Y=ab
Z=aa

Likewise, ZY is now the most frequent pair (it occurs twice), so we replace it with X:
XdXac
X=ZY
Y=ab
Z=aa

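At this point every pair of adjacent characters occurs only once, so the process stops. It is that simple.
To decode, apply the replacements in reverse order: expand X first, then Y, then Z.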
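
The decoding step of this toy example can be sketched in a few lines of Python. This works on characters, not bytes, and the function name `expand` is just an illustration; the real byte-level decoder appears later in the post.

```python
def expand(text: str, rules: list[tuple[str, str]]) -> str:
    """Undo the substitutions one rule at a time, newest rule first."""
    for symbol, pair in rules:
        text = text.replace(symbol, pair)
    return text

# Rules from the walkthrough, listed newest first.
rules = [("X", "ZY"), ("Y", "ab"), ("Z", "aa")]

decoded = expand("XdXac", rules)
print(decoded)  # aaabdaaabac
```

The order matters: X expands into Z and Y, which themselves still need expanding, so the newest rule has to be undone first.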

Code

Two helper functions

# Count how often each pair of adjacent elements occurs
def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

# Replace every occurrence of pair in ids with idx
def merge(ids, pair, idx):
    """
    In the list of integers (ids), replace all consecutive occurrences
    of pair with the new integer token idx
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    newids = []
    i = 0
    while i < len(ids):
        # if not at the very last position AND the pair matches, replace it
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
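
As a quick check, the sketch below runs both helpers on the sample string from the walkthrough. The helpers are restated in compact form only so that the block runs on its own; they are meant to behave like the versions above.

```python
def get_stats(ids):
    """Count every pair of adjacent elements, including overlapping ones."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Scan left to right, replacing each non-overlapping pair with idx."""
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

ids = list("aaabdaaabac".encode("utf-8"))  # 97 = "a", 98 = "b", ...
stats = get_stats(ids)
print(stats[(97, 97)])  # 4
print(stats[(97, 98)])  # 2

merged = merge(ids, (97, 97), 256)
print(merged)  # [256, 97, 98, 100, 256, 97, 98, 97, 99]
```

Note that `get_stats` counts overlapping pairs, so "aaa" contributes two occurrences of (97, 97), while `merge` consumes pairs greedily from the left. That is why a count of 4 results in only two replacements here.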

Training

# BPE training
text = "aaabdaaabac" # training corpus
vocab_size = 256 + 3 # vocabulary size: 256 byte tokens, then 3 merges

text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255

# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int

vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
num_merges = vocab_size - 256

for i in range(num_merges):
    if len(ids) <= 1:
        break
    # count up the number of times every consecutive pair appears
    stats = get_stats(ids)
    # find the pair with the highest count
    pair = max(stats, key=stats.get)
    # mint a new token: assign it the next available id
    idx = 256 + i
    # replace all occurrences of pair in ids with idx
    ids = merge(ids, pair, idx)
    # save the merge
    merges[pair] = idx  # used in encode
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]   # used in decode
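
To inspect what the loop learns, the sketch below wraps the same steps in a function and prints the result for the sample text. The wrapper name `train` is my own, and the helpers are restated so the block runs on its own.

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def train(text, num_merges):
    """Run up to num_merges BPE merges on the UTF-8 bytes of text."""
    ids = list(text.encode("utf-8"))
    merges = {}                                         # (int, int) -> int
    vocab = {idx: bytes([idx]) for idx in range(256)}   # int -> bytes
    for i in range(num_merges):
        if len(ids) <= 1:
            break
        stats = get_stats(ids)
        # On a tie, max() returns the first pair that was inserted into stats.
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    return ids, merges, vocab

ids, merges, vocab = train("aaabdaaabac", 3)
print(ids)         # [258, 100, 258, 97, 99]
print(merges)      # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
print(vocab[258])  # b'aaab'
```

The second merge here is (256, 97), that is "Za" in the notation of the walkthrough, not "ab". Both pairs occur twice at that point, and `max` returns the first one it encounters. The final sequence still has the same shape as XdXac.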

Encoding and decoding

# Encode
text_bytes = text.encode("utf-8") # raw bytes
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
    # find the pair with the lowest merge index
    stats = get_stats(ids)
    # among the pairs in stats, pick the one with the lowest merge idx and merge it
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    # subtle: if there are no more merges available, the key will
    # result in an inf for every single pair, and the min will be
    # just the first pair in the list, arbitrarily
    # we can detect this terminating case by a membership check
    if pair not in merges:
        break # nothing else can be merged anymore
    # otherwise let's merge the best pair (lowest merge index)
    idx = merges[pair]
    ids = merge(ids, pair, idx)
print(ids)

# Decode
# given ids (list of integers), return Python string
text_bytes = b"".join(vocab[idx] for idx in ids)
text = text_bytes.decode("utf-8", errors="replace")
print(text)
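
The same logic can be packaged as reusable functions, which makes it easy to try text other than the training string. This is a sketch under assumptions: the names `encode`, `decode`, and `build_vocab` are my own, and the merge table is hard-coded to what I expect the training loop above to produce on the sample text.

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def build_vocab(merges):
    """Rebuild int -> bytes from the merges, in the order they were learned."""
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():  # dicts preserve insertion order
        vocab[idx] = vocab[p0] + vocab[p1]
    return vocab

def encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_stats(ids)
        # Apply merges in the order they were learned (lowest idx first).
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no pair in ids has a learned merge
        ids = merge(ids, pair, merges[pair])
    return ids

def decode(ids, vocab):
    text_bytes = b"".join(vocab[idx] for idx in ids)
    return text_bytes.decode("utf-8", errors="replace")

# Assumed merge table (see the note above).
merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}
vocab = build_vocab(merges)

print(encode("aaab", merges))   # [258]
print(encode("hello", merges))  # [104, 101, 108, 108, 111]
print(decode(encode("aaabdaaabac", merges), vocab))  # aaabdaaabac
print(repr(decode([128], vocab)))  # '\ufffd'
```

The last line shows why `errors="replace"` is used: an arbitrary id sequence does not have to form valid UTF-8, and a stray byte such as 0x80 is decoded to the replacement character instead of raising an exception.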

References

Byte pair encoding
karpathy/minbpe
tokenizers
一分钟搞懂的算法之BPE算法 (Understanding the BPE Algorithm in One Minute)
