字典树(Trie)是一种十分强大的高级数据结构,常用于字典和文件目录。
首先是树节点的设计:我们用一个数组和一个值对象来表示一个树节点。把字母映射为数组索引是一种常用技巧。键并不显式存储,而是由从根节点到当前节点路径上的字符依次拼接而成。当某节点的值对象为空时,表明当前路径组成的字符串不在字典中;不为空则在字典中。
举个例子:假设字母表只有A、B、C、D四个字母,依次插入ABAB、ABBAB、BADA、BCDA、AB、ABBA、ABCD、BCD,结果如下:
代码如下:
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Queue;
/**
 * A symbol table with {@code String} keys, backed by an R-way trie with {@code R = 256}.
 *
 * <p>Each node stores an optional value and an array of {@code R} child links. A node's key is
 * the sequence of characters on the path from the root to that node. A node whose value is
 * {@code null} does not correspond to a key in the table.
 *
 * <p>Keys are indexed directly by character code, so every character in a key must have a code
 * below {@link #R}. Other characters cause an {@code ArrayIndexOutOfBoundsException}.
 *
 * @param <T> the value type
 */
public class TrieST<T> {
    /** Alphabet size: one child link per character code 0..R-1. */
    public static final int R = 256;

    /** Root node. It is created in the constructor and never replaced; it represents "". */
    public TrieNode root;

    /** Number of keys currently stored (nodes with a non-null value). */
    private int n = 0;

    /** A trie node: an optional value plus one child link per alphabet character. */
    private static class TrieNode {
        private Object val;
        private final TrieNode[] next = new TrieNode[R];
    }

    /** Creates an empty symbol table. */
    public TrieST() {
        this.root = new TrieNode();
    }

    /**
     * Associates {@code value} with {@code key}, overwriting any previous value.
     * A {@code null} value removes {@code key} from the table.
     *
     * @param key the key
     * @param value the value, or {@code null} to delete the key
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void put(String key, T value) {
        if (key == null) {
            throw new IllegalArgumentException("first argument to put() is null");
        }
        if (value == null) {
            // Return right after deleting. Otherwise the path would be recreated with a
            // null value and the key would be counted again.
            delete(key);
            return;
        }
        this.put(this.root, key, value, 0);
    }

    /**
     * Recursive helper for {@link #put(String, Object)}. It stores {@code value} at the node
     * reached by {@code key.substring(d)} from {@code node}, creating nodes as needed.
     *
     * @return the (possibly newly created) node for position {@code d}
     */
    public TrieNode put(TrieNode node, String key, T value, int d) {
        if (node == null) {
            node = new TrieNode();
        }
        if (key.length() == d) {
            // Keep the key count consistent for insertions, overwrites and null stores.
            if (node.val == null && value != null) {
                this.n++;
            } else if (node.val != null && value == null) {
                this.n--;
            }
            node.val = value;
            return node;
        }
        char c = key.charAt(d);
        node.next[c] = put(node.next[c], key, value, d + 1);
        return node;
    }

    /**
     * Returns whether the table contains {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public boolean containsKey(String key) {
        if (key == null) {
            throw new IllegalArgumentException("argument to containsKey() is null");
        }
        return get(key) != null;
    }

    /**
     * Returns the value associated with {@code key}, or {@code null} if the key is absent.
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    @SuppressWarnings("unchecked") // values only enter the trie through put(..., T value, ...)
    public T get(String key) {
        if (key == null) {
            throw new IllegalArgumentException("argument to get() is null");
        }
        TrieNode node = get(this.root, key, 0);
        return node == null ? null : (T) node.val;
    }

    /**
     * Follows the characters {@code key.charAt(d)}, {@code key.charAt(d+1)}, ... from
     * {@code node}.
     *
     * @return the node reached at the end of {@code key}, or {@code null} if the path breaks
     */
    public TrieNode get(TrieNode node, String key, int d) {
        for (int i = d; node != null && i < key.length(); i++) {
            node = node.next[key.charAt(i)];
        }
        return node;
    }

    /**
     * Removes {@code key} and its value from the table, if present.
     *
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void delete(String key) {
        if (key == null) {
            throw new IllegalArgumentException("argument to delete() is null");
        }
        // The root is always kept, even when it becomes empty, so the pruned result is ignored.
        delete(root, key, 0);
    }

    /**
     * Clears the value for {@code key} below {@code node} and prunes nodes that end up with
     * neither a value nor children.
     *
     * @return {@code node}, or {@code null} if it was pruned
     */
    private TrieNode delete(TrieNode node, String key, int d) {
        if (node == null) {
            return null;
        }
        if (d == key.length()) {
            if (node.val != null) {
                this.n--;
            }
            node.val = null;
        } else {
            char c = key.charAt(d);
            node.next[c] = delete(node.next[c], key, d + 1);
        }
        if (node.val != null) {
            return node;
        }
        for (int c = 0; c < R; c++) {
            if (node.next[c] != null) {
                return node;
            }
        }
        return null;
    }

    /** Returns the number of keys in the table. */
    public int size() {
        return this.n;
    }

    /** Returns whether the table holds no keys. */
    public boolean isEmpty() {
        return this.n == 0;
    }

    /**
     * Returns all keys in the symbol table, in ascending character-code order, as an
     * {@code Iterable}. Use it as {@code for (String key : st.keys())}.
     *
     * @return all keys in the symbol table
     */
    public Iterable<String> keys() {
        Queue<String> q = new LinkedList<>();
        collect(this.root, new StringBuilder(), q);
        return q;
    }

    /**
     * Returns all keys in the table that start with {@code prefix}, in ascending order.
     *
     * @param prefix the prefix
     * @return the matching keys
     * @throws IllegalArgumentException if {@code prefix} is {@code null}
     */
    public Iterable<String> keysWithPrefix(String prefix) {
        if (prefix == null) {
            throw new IllegalArgumentException("argument to keysWithPrefix() is null");
        }
        Queue<String> q = new LinkedList<>();
        TrieNode node = get(this.root, prefix, 0);
        collect(node, new StringBuilder(prefix), q);
        return q;
    }

    /** Adds to {@code results} every key in the subtrie at {@code node}, prefixed by {@code prefix}. */
    private void collect(TrieNode node, StringBuilder prefix, Queue<String> results) {
        if (node == null) {
            return;
        }
        if (node.val != null) {
            results.add(prefix.toString());
        }
        for (int c = 0; c < R; c++) {
            if (node.next[c] == null) {
                continue; // skip empty links; there is nothing to collect below them
            }
            prefix.append((char) c);
            collect(node.next[c], prefix, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    /**
     * Returns all keys in the table that match {@code pattern}, where {@code .} matches
     * any single character.
     *
     * @param pattern the pattern
     * @return the matching keys
     * @throws IllegalArgumentException if {@code pattern} is {@code null}
     */
    public Iterable<String> keysThatMatch(String pattern) {
        if (pattern == null) {
            throw new IllegalArgumentException("argument to keysThatMatch() is null");
        }
        Queue<String> q = new LinkedList<>();
        collect(this.root, new StringBuilder(), pattern, q);
        return q;
    }

    /**
     * Adds to {@code results} every key under {@code node} that matches {@code pattern}.
     * {@code prefix} holds the characters already matched, so {@code prefix.length()} is the
     * current position in the pattern.
     */
    private void collect(TrieNode node, StringBuilder prefix, String pattern, Queue<String> results) {
        if (node == null) {
            return;
        }
        int d = prefix.length();
        if (d == pattern.length()) {
            if (node.val != null) {
                results.add(prefix.toString());
            }
            return;
        }
        char c = pattern.charAt(d);
        if (c == '.') {
            for (int ch = 0; ch < R; ch++) {
                if (node.next[ch] == null) {
                    continue;
                }
                prefix.append((char) ch);
                collect(node.next[ch], prefix, pattern, results);
                prefix.deleteCharAt(prefix.length() - 1);
            }
        } else {
            prefix.append(c);
            collect(node.next[c], prefix, pattern, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    /**
     * Returns the longest key in the table that is a prefix of {@code query}, or {@code null}
     * if there is no such key.
     *
     * @param query the query string
     * @return the longest key that is a prefix of {@code query}, or {@code null}
     * @throws IllegalArgumentException if {@code query} is {@code null}
     */
    public String longestPrefixOf(String query) {
        if (query == null) {
            throw new IllegalArgumentException("argument to longestPrefixOf() is null");
        }
        int length = longestPrefixOf(this.root, query, 0, -1);
        // -1 means no key is a prefix. Without this check, substring(0, -1) would throw.
        return length == -1 ? null : query.substring(0, length);
    }

    // Returns the length of the longest key in the subtrie rooted at node that is a prefix
    // of query. It assumes the first d characters already match and that a prefix match of
    // the given length has been found (-1 if none yet).
    private int longestPrefixOf(TrieNode node, String query, int d, int length) {
        if (node == null) {
            return length;
        }
        if (node.val != null) {
            length = d;
        }
        if (d == query.length()) {
            return length;
        }
        char c = query.charAt(d);
        return longestPrefixOf(node.next[c], query, d + 1, length);
    }

    /** Small demo: inserts sample keys and prints the results of the query methods. */
    public static void main(String[] args) {
        TrieST<String> st = new TrieST<>();
        System.out.println(st.size());
        String[] strs = {"AB", "ABBA", "ABCD", "BCD", "ABAB", "ABBAB", "BADA", "BCDA"};
        for (String s : strs) {
            if (s != null) {
                System.out.println(s);
                st.put(s, s);
            }
        }
        System.out.println(st.containsKey("ABAB"));
        for (String key : st.keys()) {
            System.out.println(key);
        }
        System.out.println(st.containsKey("BADC"));
        st.put("ABCC", "ABCC");
        st.put("CBAC", "CBAC");
        for (String key : st.keysThatMatch("AB")) {
            System.out.println(key);
        }
        System.out.println(st.containsKey("BADC"));
        for (String key : st.keysWithPrefix("AB")) {
            System.out.println(key);
        }
        System.out.println("hahha");
    }
}
很明显,这种结构存在大量空指针,造成内存浪费:每个节点都分配了R=256个子链接,而其中绝大多数为空。
后来,我在刷LeetCode时发现了一种更简洁的节点定义方式,链接:https://leetcode.com/problems/implement-trie-prefix-tree/solution/
每个节点中含有一个布尔变量isEnd,用来标记该节点是否为某个单词的最后一个字符。节点定义如下:
/**
 * A trie node for words over the lowercase letters 'a'..'z'.
 *
 * <p>Each node has one child link per letter and an {@code isEnd} flag. The flag records
 * whether the path from the root to this node spells a complete inserted word.
 */
class TrieNode {
    /** Alphabet size ('a'..'z'). It is static so the constant is not stored in every node. */
    private static final int R = 26;

    // R links to node children: links[i] is the child for letter (char) ('a' + i), or null.
    private final TrieNode[] links;

    // True if some inserted word ends at this node.
    private boolean isEnd;

    /** Creates a node with no children that does not end a word. */
    public TrieNode() {
        links = new TrieNode[R];
    }

    /**
     * Returns whether this node has a child for {@code ch}.
     *
     * @throws IllegalArgumentException if {@code ch} is not in 'a'..'z'
     */
    public boolean containsKey(char ch) {
        return links[index(ch)] != null;
    }

    /**
     * Returns the child for {@code ch}, or {@code null} if there is none.
     *
     * @throws IllegalArgumentException if {@code ch} is not in 'a'..'z'
     */
    public TrieNode get(char ch) {
        return links[index(ch)];
    }

    /**
     * Sets the child for {@code ch} to {@code node}.
     *
     * @throws IllegalArgumentException if {@code ch} is not in 'a'..'z'
     */
    public void put(char ch, TrieNode node) {
        links[index(ch)] = node;
    }

    /** Marks this node as the end of a word. */
    public void setEnd() {
        isEnd = true;
    }

    /** Returns whether a word ends at this node. */
    public boolean isEnd() {
        return isEnd;
    }

    /**
     * Maps a lowercase letter to its index in {@code links}. Validating here replaces an
     * opaque ArrayIndexOutOfBoundsException with a descriptive error.
     */
    private static int index(char ch) {
        int i = ch - 'a';
        if (i < 0 || i >= R) {
            throw new IllegalArgumentException("character not in 'a'..'z': " + ch);
        }
        return i;
    }
}
查询和插入操作都借助节点自身提供的方法(containsKey、get、put)完成。
Trie的定义:
/**
 * A prefix tree over lowercase words, built from {@code TrieNode}s.
 * It supports insertion, exact-word lookup and prefix lookup.
 */
class Trie {
    // Sentinel node standing for the empty prefix; it is never replaced.
    private final TrieNode root;

    /** Creates an empty trie. */
    public Trie() {
        root = new TrieNode();
    }

    /**
     * Inserts {@code word}. Missing nodes along its path are created, and the last node is
     * flagged as a word end.
     */
    public void insert(String word) {
        TrieNode cursor = root;
        for (char letter : word.toCharArray()) {
            TrieNode child = cursor.get(letter);
            if (child == null) {
                child = new TrieNode();
                cursor.put(letter, child);
            }
            cursor = child;
        }
        cursor.setEnd();
    }

    /**
     * Walks the path spelled by {@code word} from the root.
     *
     * @return the node where the walk ends, or {@code null} if some character has no child
     */
    private TrieNode searchPrefix(String word) {
        TrieNode cursor = root;
        int pos = 0;
        while (cursor != null && pos < word.length()) {
            cursor = cursor.get(word.charAt(pos));
            pos++;
        }
        return cursor;
    }

    /** Returns whether {@code word} was previously inserted as a complete word. */
    public boolean search(String word) {
        TrieNode last = searchPrefix(word);
        if (last == null) {
            return false;
        }
        return last.isEnd();
    }

    /** Returns whether any inserted word begins with {@code prefix}. */
    public boolean startsWith(String prefix) {
        return searchPrefix(prefix) != null;
    }
}
这种设计方式很值得学习和借鉴:比起上一种设计,它的插入操作十分简洁易懂;而且每个节点只为26个小写字母分配链接,数组比R=256小得多(代价是只能处理a到z的字符)。