# BPE-based Chinese Tokenization
## Environment
- OS: Windows 10
- Python version: 3.7.8
## Algorithm Description
### train.py
#### `get_vocab`: build the base word dictionary from the training corpus
```python
def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding="utf-8") as f:
        # Treat each line of the corpus as one "word"
        for line in f:
            word = line.strip()
            if not word:
                continue
            # Split the line into single characters and append the </w> end-of-word marker
            vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab
```
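For example, with a hypothetical one-line training file whose only content is `自然语言处理`, the dictionary holds a single character-separated entry (the file name below is illustrative only):

```python
# Hypothetical toy corpus: a file 'toy_corpus.txt' whose only line is "自然语言处理"
vocab = get_vocab('toy_corpus.txt')
print(dict(vocab))
# {'自 然 语 言 处 理 </w>': 1}
```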
#### `get_tokens_from_vocab`: derive the initial token list from the dictionary
```python
def get_tokens_from_vocab(vocab):
    tokens_frequencies = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens_frequencies[token] += freq
    return tokens_frequencies
```
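Continuing the toy example from `get_vocab`, each character plus the `</w>` marker becomes an initial token:

```python
# Continues the hypothetical single-line example above
tokens_frequencies = get_tokens_from_vocab({'自 然 语 言 处 理 </w>': 1})
print(dict(tokens_frequencies))
# {'自': 1, '然': 1, '语': 1, '言': 1, '处': 1, '理': 1, '</w>': 1}
```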
#### `get_stats`: count symbol-pair frequencies from the dictionary
```python
def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        # Count every pair of adjacent symbols
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs
```
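On the same toy dictionary, every pair of adjacent symbols gets a count:

```python
# Continues the hypothetical example: adjacent-symbol pair counts
pairs = get_stats({'自 然 语 言 处 理 </w>': 1})
print(dict(pairs))
# {('自', '然'): 1, ('然', '语'): 1, ('语', '言'): 1,
#  ('言', '处'): 1, ('处', '理'): 1, ('理', '</w>'): 1}
```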
#### `merge_vocab`: merge a pair and replace its occurrences inside the dictionary words
```python
def merge_vocab(pair, v_in):
    v_out = {}
    # Match the pair only when it appears as two complete, space-separated symbols
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        # Glue the matched pair into a single symbol
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out
```
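For example, merging the pair `('语', '言')` joins those two symbols everywhere they appear next to each other:

```python
# Continues the hypothetical example: merge the pair ('语', '言')
v_out = merge_vocab(('语', '言'), {'自 然 语 言 处 理 </w>': 1})
print(v_out)
# {'自 然 语言 处 理 </w>': 1}
```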
#### `train`: build the final token list and save it
```python
def train(filename, k):
    vocab = get_vocab(filename)
    # Build the initial token list
    tokens_frequencies = get_tokens_from_vocab(vocab)
    # Stop once the token list reaches size k
    while len(tokens_frequencies) < k:
        pairs = get_stats(vocab)
        # Stop early if there is no pair left to merge
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        # Merge the most frequent pair and add it to the token list
        vocab = merge_vocab(best, vocab)
        tokens_frequencies[''.join(best)] += 1
    # Sort the token list: longer tokens first, ties broken by frequency
    # (measure_token_length, defined in the full source below, counts </w> as length 1)
    sorted_tokens_tuple = sorted(
        tokens_frequencies.items(),
        key=lambda item: (measure_token_length(item[0]), item[1]),
        reverse=True
    )
    # Save the sorted token list; keep this file name in sync with the one loaded in tokenlize.py
    with open('sorted_tokens', 'wb') as F:
        pickle.dump(sorted_tokens_tuple, F)
```
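Training is then a single call. A minimal usage sketch, following the `__main__` block of train.py (the corpus file name and vocabulary size 10000 come from there), with a quick check of the pickled result:

```python
import pickle

# Train on the corpus and pickle the sorted token list to 'sorted_tokens'
train('train_BPE.txt', 10000)

with open('sorted_tokens', 'rb') as f:
    sorted_tokens_tuple = pickle.load(f)
print(sorted_tokens_tuple[:5])  # longest / most frequent tokens first
```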
### tokenlize.py
#### `tokenize_word`: tokenize a string
```python
def tokenize_word(string, sorted_tokens, unknown_token='</u>'):
    if string == '':
        return []
    if sorted_tokens == []:
        return [unknown_token]
    string_tokens = []
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        token_reg = re.escape(token)
        # Find the start and end positions of every match of this token
        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
        if len(matched_positions) == 0:
            continue
        # The start of each match marks the end of the preceding substring
        substring_end_positions = [matched_position[0] for matched_position in matched_positions]
        substring_start_position = 0
        for substring_end_position in substring_end_positions:
            substring = string[substring_start_position:substring_end_position]
            # Tokenize the substring before the match with the remaining, lower-priority tokens
            string_tokens += tokenize_word(substring, sorted_tokens[i + 1:], unknown_token)
            string_tokens += [token]
            # Move the start position past the matched token
            substring_start_position = substring_end_position + len(token)
        # Tokenize whatever is left after the last match
        remaining_substring = string[substring_start_position:]
        string_tokens += tokenize_word(remaining_substring, sorted_tokens[i + 1:], unknown_token)
        break
    return string_tokens
```
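A quick sanity check with a hand-written token list (hypothetical, not one produced by training): the first matching token splits the string, and the pieces around it are tokenized recursively with the remaining tokens:

```python
# Hypothetical token list, already ordered from longest to shortest
sorted_tokens = ['自然', '语', '言', '处', '理']
print(tokenize_word('自然语言处理', sorted_tokens))
# ['自然', '语', '言', '处', '理']
```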
#### `bpe`: run tokenization and write the result to a txt file
```python
def bpe(filename, tokelizedfile, modelname):
    # Load the sorted token list produced by train.py
    with open(modelname, 'rb') as model:
        E = pickle.load(model)
    sorted_tokens = [token for (token, freq) in E]
    with open(filename, 'r', encoding="utf-8") as f:
        text = ''
        for line in f:
            # Remove the spaces in the input and concatenate the lines
            text = text + ''.join(line.split(' '))
    text_tokelized = tokenize_word(str(text), sorted_tokens)
    # Separate the resulting tokens with spaces
    text_tokelized = ' '.join(text_tokelized)
    with open(tokelizedfile, "w", encoding="utf-8") as f:
        f.write(text_tokelized)
```
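End-to-end usage mirrors the `__main__` block in the source below; note that `train` above pickles the model as `sorted_tokens`, so the `modelname` argument has to point at that file (or the two names must be kept in sync):

```python
# Sketch of a full run; input/output file names follow the __main__ blocks below,
# while the model name matches the file written by train() above.
bpe('test_BPE.txt', 'text_tokelized.txt', 'sorted_tokens')
```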
## Source Code
### train.py
```python
'''
Description: train
Autor: LarFii
LastEditTime: 2021-03-25 01:34:48
'''
import os
import re
import sys
import pickle
import collections
from tqdm import tqdm


def get_vocab(filename):
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding="utf-8") as f:
        for line in f:
            word = line.strip()
            if not word:
                continue
            # Split into single characters and append the </w> end-of-word marker
            vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab


def get_tokens_from_vocab(vocab):
    tokens_frequencies = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens_frequencies[token] += freq
    return tokens_frequencies


def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs


def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out


def measure_token_length(token):
    # Count </w> as a single symbol when measuring token length
    if token[-4:] == '</w>':
        return len(token[:-4]) + 1
    else:
        return len(token)


def train(filename, k):
    vocab = get_vocab(filename)
    tokens_frequencies = get_tokens_from_vocab(vocab)
    while len(tokens_frequencies) < k:
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
        tokens_frequencies[''.join(best)] += 1
    sorted_tokens_tuple = sorted(
        tokens_frequencies.items(),
        key=lambda item: (measure_token_length(item[0]), item[1]),
        reverse=True
    )
    with open('sorted_tokens', 'wb') as F:
        pickle.dump(sorted_tokens_tuple, F)


if __name__ == '__main__':
    train('train_BPE.txt', 10000)
```
### tokenlize.py
```python
'''
Description: tokenlize
Autor: LarFii
LastEditTime: 2021-03-25 01:30:28
'''
import re
import sys
import pickle
import collections
from tqdm import tqdm


def tokenize_word(string, sorted_tokens, unknown_token='</u>'):
    if string == '':
        return []
    if sorted_tokens == []:
        return [unknown_token]
    string_tokens = []
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        token_reg = re.escape(token)
        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
        if len(matched_positions) == 0:
            continue
        substring_end_positions = [matched_position[0] for matched_position in matched_positions]
        substring_start_position = 0
        for substring_end_position in substring_end_positions:
            substring = string[substring_start_position:substring_end_position]
            string_tokens += tokenize_word(substring, sorted_tokens[i + 1:], unknown_token)
            string_tokens += [token]
            substring_start_position = substring_end_position + len(token)
        remaining_substring = string[substring_start_position:]
        string_tokens += tokenize_word(remaining_substring, sorted_tokens[i + 1:], unknown_token)
        break
    return string_tokens


def bpe(filename, tokelizedfile, modelname):
    with open(modelname, 'rb') as model:
        E = pickle.load(model)
    sorted_tokens = [token for (token, freq) in E]
    with open(filename, 'r', encoding="utf-8") as f:
        text = ''
        for line in f:
            text = text + ''.join(line.split(' '))
    text_tokelized = tokenize_word(str(text), sorted_tokens)
    text_tokelized = ' '.join(text_tokelized)
    with open(tokelizedfile, "w", encoding="utf-8") as f:
        f.write(text_tokelized)


if __name__ == '__main__':
    # Note: train.py saves its model as 'sorted_tokens'; keep the model file name consistent
    bpe('test_BPE.txt', "text_tokelized.txt", 'tokelize.pkl')
```