import re
import collections
def get_vocab(filename):
vocab = collections.Counter()
with open(filename, "r", encoding="utf-8") as f:
for line in f:
words = line.strip().split()
for word in words:
vocab[" ".join(list(word)) + " </w>"] += 1 # 将每个单词分成字符,并在末尾加上</w>
return vocab
def get_stats(vocab):
pairs = collections.defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols) - 1):
pairs[symbols[i], symbols[i + 1]] += freq # 统计每对字符的频率
return pairs
def merge_vocab(pair, v_in):
v_out = {}
bigram = re.escape(" ".join(pair))
p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)") # 匹配目标字符对
for word in v_in:
w_out = p.sub("".join(pair), word) # 将目标字符对替换为新字符
v_out[w_out] = v_in[word]
return v_out
vocab = get_vocab("test.txt") # test.txt是一个文本文件,包含若干单词
num_merges = 10 # 合并次数
for i in range(num_merges):
pairs = get_stats(vocab)
best = max(pairs, key=pairs.get) # 找出频率最高的字符对
vocab = merge_vocab(best, vocab)
print(f"Merge {i + 1}: {best}")
BPE编码Python实现
最新推荐文章于 2024-04-24 13:35:29 发布