// C++ code
#include <cstring>
#include <string>
/// A single node of a trie over the lowercase letters 'a'..'z'.
///
/// `pass` counts how many stored words (duplicates included) run through this
/// node, `end` counts how many stored words terminate exactly here, and `word`
/// holds the word that terminates here (empty when `end == 0`).
class TrieNode {
public:
    static const int MaxBranchNum = 26;  // one child slot per letter 'a'..'z'

    int pass = 0;
    int end = 0;
    std::string word;
    // Value-initialized so every child pointer starts out as nullptr.
    TrieNode* next[MaxBranchNum] = {};

    TrieNode() = default;
};

/// Trie of words made only of lowercase ASCII letters, counting duplicates.
///
/// The tree owns every node reachable from `root`: a branch is freed as soon
/// as the last word running through it is removed, and all remaining nodes are
/// freed by the destructor. Words containing any character outside 'a'..'z'
/// (and the empty string) are never stored.
class TrieTree {
private:
    TrieNode* root;

    // Maps 'a'..'z' to 0..25; any other character yields -1.
    static int indexOf(char c);
    // True when `str` is non-empty and consists only of 'a'..'z'.
    static bool isValidWord(const std::string& str);
    // Recursively frees `node` and all of its descendants (nullptr is a no-op).
    static void destroy(TrieNode* node);

public:
    TrieTree();
    ~TrieTree();

    // The tree owns raw node pointers; a shallow copy would double-free them.
    TrieTree(const TrieTree&) = delete;
    TrieTree& operator=(const TrieTree&) = delete;

    /// Adds one occurrence of `str`. Empty or invalid words are ignored.
    void insert(const std::string& str);

    /// Returns how many times `str` has been inserted (0 if absent/invalid).
    int search(const std::string& str) const;

    /// Removes one occurrence of `str`; a no-op when `str` is not stored.
    /// (Named `remove` because `delete` is a reserved C++ keyword and cannot
    /// be used as a member function name.)
    void remove(const std::string& str);
};

TrieTree::TrieTree() : root(new TrieNode()) {}

TrieTree::~TrieTree() {
    destroy(root);
}

int TrieTree::indexOf(char c) {
    return (c >= 'a' && c <= 'z') ? c - 'a' : -1;
}

bool TrieTree::isValidWord(const std::string& str) {
    if (str.empty())
        return false;
    for (char c : str) {
        if (indexOf(c) < 0)
            return false;
    }
    return true;
}

void TrieTree::destroy(TrieNode* node) {
    if (node == nullptr)
        return;
    for (TrieNode* child : node->next)
        destroy(child);
    delete node;
}

void TrieTree::insert(const std::string& str) {
    // Validate the whole word first so an invalid character never leaves a
    // partially built path (with bumped `pass` counts) behind.
    if (!isValidWord(str))
        return;
    TrieNode* node = root;
    for (char c : str) {
        const int index = indexOf(c);
        if (node->next[index] == nullptr)
            node->next[index] = new TrieNode();
        node = node->next[index];
        node->pass++;
    }
    node->end++;
    node->word = str;
}

int TrieTree::search(const std::string& str) const {
    if (str.empty())
        return 0;
    const TrieNode* node = root;
    for (char c : str) {
        const int index = indexOf(c);
        if (index < 0 || node->next[index] == nullptr)
            return 0;
        node = node->next[index];
    }
    return node->end;
}

void TrieTree::remove(const std::string& str) {
    // search() rejects empty, invalid and absent words, so below every index
    // is in range and every child on the path exists.
    if (search(str) == 0)
        return;
    TrieNode* node = root;
    for (char c : str) {
        const int index = indexOf(c);
        TrieNode* child = node->next[index];
        if (--child->pass == 0) {
            // This word was the only one using the branch: unlink and free
            // the rest of its path instead of leaking it.
            node->next[index] = nullptr;
            destroy(child);
            return;
        }
        node = child;
    }
    if (--node->end == 0)
        node->word.clear();
}