BK树(Burkhard-Keller树)是一种基于树的数据结构,用于快速进行近似字符串匹配,例如拼写纠错或模糊查找:当搜索“aeek”时,能返回与其最相似的字符串“seek”和“peek”。
在构建BK树之前,我们需要定义一种用于比较字符串相似度的度量方法。通常采用编辑距离(Levenshtein Distance),即把一个字符串转换为另一个字符串所需的最少编辑操作(插入、删除或替换单个字符)次数。
在确定度量方法后,可以构建出基于该比较方法的度量空间,该空间具有以下3种特性:
假设存在字符串A、B、C,d(A,B)表示两字符串的编辑距离
1.如果d(A,B)=0,那么表示A,B字符串相等
2.d(A,B)与d(B,A)相等
3.d(A,C) <= d(A,B)+d(B,C)
最后一条又叫做三角不等式,表示A与C的编辑距离一定不大于A先变为B、再由B变为C的编辑距离之和。
BK建树
首先我们随便找一个单词作为根(比如GAME)。以后插入一个单词时首先计算单词与根的Levenshtein距离:如果这个距离值是该节点处头一次出现,建立一个新的儿子节点;否则沿着对应的边递归下去。例如,我们插入单词FAME,它与GAME的距离为1,于是新建一个儿子,连一条标号为1的边;下一次插入GAIN,算得它与GAME的距离为2,于是放在编号为2的边下。再下次我们插入GATE,它与GAME距离为1,于是沿着那条编号为1的边下去,递归地插入到FAME所在子树;GATE与FAME的距离为2,于是把GATE放在FAME节点下,边的编号为2。
BK查询
如果我们需要返回与错误单词距离不超过n的单词,这个错误单词与树根所对应的单词距离为d,那么接下来我们只需要递归地考虑编号在d-n到d+n范围内的边所连接的子树。由于n通常很小,因此每次与某个节点进行比较时都可以排除很多子树。
推论如下(设A为当前节点的单词,B为以A为根的子树中任意一个满足 d(query, B) <= n 的单词,d = d(query, A)):
d(query, B) + d(B, A) >= d(query, A), 即 d(query, B) + d(A,B) >= d
--> d(A,B) >= d - d(query, B) >= d - n
d(A, B) <= d(A,query) + d(query, B), 即 d(A, B) <= d + d(query, B) <= d + n
其实,还可以得到 d(query, A) + d(A,B) >= d(query, B)
--> d(A,B) >= d(query, B) - d(query, A) = d(query, B) - d
由于对 d(query, B) 只知道它不小于0,这个式子只能给出 d(A,B) >= -d,并没有带来新的约束
但 A与B不是同一个字符串,所以一定有 d(A,B)>=1
所以, max{1, d - n} <= d(A,B) <= d + n,这是更为完整的结论。
实现:
整体实现需要有构建树与查询树两块功能,查询时需要返回编辑距离与节点的字符串
1.首先实现编辑距离计算方法
def calculate_edit_distance(word1, word2):
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-item insertions,
    deletions and substitutions needed to turn ``word1`` into ``word2``.

    Args:
        word1: First string (any iterable, sized sequence works).
        word2: Second string.

    Returns:
        The edit distance as a plain ``int``.
    """
    # Only the previous DP row is needed to compute the current one, so two
    # plain lists replace the full (len1 + 1) x (len2 + 1) matrix. This keeps
    # every value an int and needs no third-party array library.
    # previous_row[j] == distance between word1[:i - 1] and word2[:j].
    previous_row = list(range(len(word2) + 1))
    for i, item1 in enumerate(word1, start=1):
        # Column 0: turning word1[:i] into "" takes i deletions.
        current_row = [i]
        for j, item2 in enumerate(word2, start=1):
            delta = 0 if item1 == item2 else 1
            current_row.append(min(
                previous_row[j - 1] + delta,  # substitute (or keep a match)
                previous_row[j] + 1,          # delete from word1
                current_row[j - 1] + 1,       # insert into word1
            ))
        previous_row = current_row
    return previous_row[-1]
2.实现结果返回类
class ResultNode:
    """A query result: a string paired with a distance value."""

    def __init__(self, data, distance):
        # data: the string being reported; distance: the distance reported
        # alongside it.
        self.data, self.distance = data, distance
3.实现节点类
class TreeNode:
    """One node of a BK-tree: a word plus child subtrees keyed by distance.

    Attributes:
        data: The word stored at this node.
        child_node_dict: Maps an edit distance ``k`` to the child subtree
            whose words are all exactly ``k`` edits away from ``data``.
    """

    def __init__(self, data):
        self.data = data
        self.child_node_dict = {}

    def put(self, chars):
        """Insert ``chars`` into the subtree rooted at this node.

        A word at distance 0 from this node's word is ignored, so the tree
        never stores duplicates.
        """
        distance = ed.calculate_edit_distance(chars, self.data)
        if distance == 0:
            return
        child = self.child_node_dict.get(distance)
        if child is None:
            # First word seen at this distance: it becomes the new child.
            self.child_node_dict[distance] = TreeNode(chars)
        else:
            child.put(chars)

    def query(self, target_char, n):
        """Collect every word in this subtree within ``n`` edits of the target.

        Args:
            target_char: The string being looked up.
            n: Maximum edit distance a result may have.

        Returns:
            A list of ``ResultNode`` objects in visit order (not sorted).
        """
        results = []
        distance = ed.calculate_edit_distance(target_char, self.data)
        if distance <= n:
            results.append(ResultNode(self.data, distance))
        # By the triangle inequality a match can only live under an edge
        # labelled within [distance - n, distance + n]; labels start at 1
        # because distance-0 words are never stored as children. Children are
        # explored even when this node is an exact match (distance == 0):
        # words on edges 1..n are then themselves within n of the target.
        for query_distance in range(max(distance - n, 1), distance + n + 1):
            child = self.child_node_dict.get(query_distance)
            if child is not None:
                results += child.query(target_char, n)
        return results

    def get_all_data(self):
        """Return every word stored in this subtree, this node's word first."""
        # Each node reports its own word, so the subtree root is included.
        results = [self.data]
        for child in self.child_node_dict.values():
            results += child.get_all_data()
        return results
4.实现树类
class BKTree:
    """Burkhard-Keller tree facade that owns the root node."""

    def __init__(self, root_chars):
        self.root_node = TreeNode(root_chars)

    def put(self, chars):
        """Insert ``chars`` by delegating to the root node."""
        self.root_node.put(chars)

    def query(self, target_char, n):
        """Return the closest candidate for ``target_char`` within ``n`` edits.

        Returns:
            The ``ResultNode`` with the smallest distance among the
            candidates reported by the root node (the first one found wins
            ties). When there is no root or no candidate, the result is
            ``ResultNode(target_char, 0)``, i.e. the query string itself
            with distance 0, so callers cannot tell "nothing found" from an
            exact match by the distance alone.
        """
        if self.root_node is not None:
            candidates = self.root_node.query(target_char, n)
            if candidates:
                return min(candidates, key=lambda hit: hit.distance)
        return ResultNode(target_char, 0)

    def get_all_data(self):
        """Return the root node's ``get_all_data()`` list, or ``[]`` without a root."""
        if self.root_node is None:
            return []
        return self.root_node.get_all_data()
5.实现树的保存和恢复
import pickle
import os
import random
from model import BKTree
from utils import read_dict
# File used to cache the pickled BK-tree; a relative path, so it is resolved
# against the current working directory when opened.
bk_tree_path = 'bk_tree.pkl'
def dump_bk_tree(bk_tree):
    """Pickle ``bk_tree`` into ``bk_tree_path``, overwriting any existing file."""
    with open(bk_tree_path, mode='wb') as cache_file:
        pickle.dump(bk_tree, cache_file)
def load_bk_tree():
    """Return the BK-tree, unpickling the cache file or building it afresh.

    If ``bk_tree_path`` exists, the tree is unpickled from it. Otherwise the
    word list is read from ``dict_en.txt``, a randomly chosen word becomes
    the root, every word is inserted (progress is printed per word), and the
    finished tree is written out with ``dump_bk_tree`` before being returned.
    """
    if not os.path.exists(bk_tree_path):
        char_list = read_dict('dict_en.txt')
        root_index = random.randint(0, len(char_list) - 1)
        bk_tree = BKTree(char_list[root_index])
        print('start build tree')
        total = len(char_list)
        for index, item in enumerate(char_list):
            print(f'build tree:{index}/{total}')
            bk_tree.put(item)
        dump_bk_tree(bk_tree)
        return bk_tree

    print('load build tree')
    with open(bk_tree_path, 'rb') as cache_file:
        # NOTE: unpickling can execute arbitrary code, so the cache file must
        # come from a trusted source.
        return pickle.load(cache_file)
6.调用测试
from load_tree import load_bk_tree
from datetime import datetime
# Load (or build) the tree, then time a single fuzzy lookup.
bk_tree = load_bk_tree()
query_word = 'lavishnessa'

started_at = datetime.now()
query = bk_tree.query(query_word, 3)  # accept results up to 3 edits away
delta_time = datetime.now() - started_at

print(f"spent:{delta_time}")
# Print the matched word, then its distance, one per line.
for field in (query.data, query.distance):
    print(field)
纯Python实现的效率较低,可以考虑使用Cython加速编辑距离的计算逻辑,可以获得近15倍的加速。
7.Cython加速 calculate edit distance
from libc.stdlib cimport malloc, free
def calculate_edit_distance(word1, word2):
    """Return the Levenshtein (edit) distance between ``word1`` and ``word2``.

    The DP table is kept in one manually allocated C buffer.

    Raises:
        MemoryError: If the table cannot be allocated.
    """
    # All C declarations come first so the indices and table cells are typed
    # C values rather than Python objects.
    cdef Py_ssize_t len1 = len(word1)
    cdef Py_ssize_t len2 = len(word2)
    cdef Py_ssize_t width = len2 + 1
    cdef Py_ssize_t i, j
    cdef int delta, cost, result
    # One contiguous (len1 + 1) x (len2 + 1) table, indexed as
    # dp[i * width + j], so a single malloc/free pair manages it.
    cdef int* dp = <int*> malloc((len1 + 1) * width * sizeof(int))
    if not dp:
        raise MemoryError()
    try:
        for i in range(len1 + 1):
            dp[i * width] = <int> i
        for j in range(width):
            dp[j] = <int> j
        for i in range(1, len1 + 1):
            for j in range(1, width):
                delta = 0 if word1[i - 1] == word2[j - 1] else 1
                # Substitute (or keep a match) ...
                cost = dp[(i - 1) * width + (j - 1)] + delta
                # ... versus deleting from word1 ...
                if dp[(i - 1) * width + j] + 1 < cost:
                    cost = dp[(i - 1) * width + j] + 1
                # ... versus inserting into word1.
                if dp[i * width + (j - 1)] + 1 < cost:
                    cost = dp[i * width + (j - 1)] + 1
                dp[i * width + j] = cost
        result = dp[len1 * width + len2]
    finally:
        # Runs on every exit path, so the buffer is released even if
        # indexing or comparing the inputs raises.
        free(dp)
    return result
Cython版本的写法与Python实现几乎一致,区别只在于dp表不再由Python对象管理,而是通过malloc/free手动分配和释放内存。