用于根据编辑距离(Levenshtein距离)概念执行拼写检查。 BK树也用于近似字符串匹配。基于该数据结构,可以实现许多软件中的各种自动校正特征。
BKTree算法上分两步:
1 构造
在词典里面随便找一个词作为root节点,然后与其他词计算编辑距离n。若已有相同编辑距离n的子节点,就挂在子节点下;若没有,就新建边为n的子节点。如此递归下去。
2 查询
这里重点来了,编辑距离符合三角不等式:任意两条边的和大于第三条边。所以只用从根节点开始,找(d-n) < x < (d+n)的边。这样可以大量减少编辑距离的计算次数,即减少O(D*m*n)中的D。
3 举例
一构造那张图为例,假如我们输入一个GAIE,程序发现它不在字典中。现在,我们想返回字典中所有与GAIE距离为1的单词。我们首先将GAIE与树根进行比较, 得到的距离d=1。由于Levenshtein距离满足三角形不等式,因此现在所有离GAME距离超过2的单词全部可以排除了
class TreeNode{
private:
string key_;
map<int, TreeNode*> *children_; // {distance: child}
public:
// 构造函数
explicit TreeNode(const string &key): key_(key), children_(NULL) {}
// 析构函数
~TreeNode() {
if (children_) {
for (auto it = children_->begin(); it != children_->end(); it ++) {
delete it->second;
}
delete children_;
}
}
// 新建树节点
bool Add(TreeNode *node) {
if (!node) {
return false;
}
int distance = levenshtein_distance(key_, node->key_);
if (!children_) {
children_ = new std::map<int, TreeNode*>();
}
// 寻找是否存在相同距离的节点,存在以此节点作为父节点,否则新建节点
auto it = children_->find(distance);
if (it == children_->end()) {
children_->insert(make_pair(distance, node));
return true;
} else {
return it->second->Add(node);
}
}
// 只用从根节点开始,找(d-n) < x < (d+n)的边。这样可以大量减少编辑距离的计算次数,减少O(D*m*n)中的D。
void Find(vector<pair<int, string>> *found, const string &key, int max_distance) {
queue<TreeNode*> candidates;
candidates.push(this);
while (!candidates.empty()) {
TreeNode* candidate = candidates.front();
candidates.pop();
int distance = levenshtein_distance(candidate->key_, key);
if (distance <= max_distance) {
found->push_back(make_pair(distance, candidate->key_));
}
if (candidate->HasChildren()) {
for (auto it = candidate->children_->begin(); it != candidate->children_->end(); it ++) {
// 只保留在(d-n) < x < (d+n)距离内的节点
if ( (distance - max_distance) <= it->first && it->first <= (distance + max_distance) ) {
candidates.push(it->second);
}
}
}
}
}
bool HasChildren() {
return children_ && children_->size();
}
};
class BKTree{
private:
TreeNode *root_;
public:
BKTree(): root_(NULL) {}
// 增加节点
void Add(const string &key) {
TreeNode *node = new TreeNode(key);
if (!root_) {
root_ = node;
} else {
root_->Add(node);
}
}
// 查找节点
vector<pair<int, string>> Find(const string &key, int max_distance) {
vector<pair<int, string>> found;
if (root_) {
root_->Find(&found, key, max_distance);
}
return found;
}
// 通过文件添加节点
int BuildFromFile(const string &filepath) {
cerr << "load vocab file: " << filepath << endl;
ifstream f(filepath);
string s, a, b, c;
int count = 0;
while (getline(f, s)) {
istringstream line(s);
// id word freq
line >> a >> b >> c;
if (count % 10000 == 0) {
cout << count << ": " << b << endl;
}
Add(b);
count++;
}
return count;
}
};