这一篇介绍三向单词查找树(Ternary Search Trie,TST)的数据结构设计。它克服了前两篇R向Trie实现中存在大量空引用、浪费内存的问题,代价是算法更加复杂。
树节点包括一个字符变量ch(表示该节点的字符)、一个值对象val(可以为null;非null时表示以该节点字符结尾的键存在于符号表中)以及三个指向子节点的引用left、mid、right,代码如下:
private static class TernaryTrieNode<T>{
    // The character held by this node.
    private char ch;
    // Value of the key whose last character is this node; null if no key ends here.
    private T val;
    // Child links: smaller characters at the same position,
    // the next character of the key, and larger characters at the same position.
    private TernaryTrieNode<T> left;
    private TernaryTrieNode<T> mid;
    private TernaryTrieNode<T> right;
}
整个树包括一个根节点引用root和符号表中键的数量size两个变量,代码如下:
public class TernaryTrieST<T>{
    // Number of keys currently stored in the symbol table.
    public int size;
    // Entry point of the ternary search trie.
    public TernaryTrieNode<T> root;
}
举个例子,字母表由ABCD组成,插入AB、ABAB、ABBAB、ABCD、BADA、BCDA、BBCA,如图:
详细代码如下:
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Queue;
/**
 * A symbol table with {@code String} keys backed by a ternary search trie (TST).
 * Each node holds one character and three child links (left/mid/right) instead of
 * an R-way array, which avoids storing large numbers of null references.
 *
 * <p>Keys must be non-null and non-empty. Values must be non-null; calling
 * {@code put(key, null)} removes {@code key} from the table.
 *
 * @param <T> the value type
 */
public class TernaryTrieST<T>{
    /** Root of the trie; {@code null} while the table has never held a key. */
    public TernaryTrieNode<T> root;
    /** Number of keys currently associated with a non-null value. */
    public int size;

    /**
     * One TST node. {@code val} is non-null iff a key ends at this node
     * (i.e. the key's last character is {@code ch}).
     */
    private static class TernaryTrieNode<T>{
        private T val;
        private char ch;
        // left: smaller char at the same position; mid: next char of the key;
        // right: larger char at the same position.
        private TernaryTrieNode<T> left,mid,right;
    }

    /** Creates an empty symbol table. */
    public TernaryTrieST()
    {
        // No sentinel root: put() creates nodes on demand and every traversal handles null.
        this.root=null;
        this.size=0;
    }

    /**
     * Inserts the key-value pair into the whole symbol table, overwriting any old value.
     * A {@code null} value deletes the key.
     *
     * @throws IllegalArgumentException if {@code key} is null or empty
     */
    public void put(String key, T val) {
        if (key == null) {
            throw new IllegalArgumentException("calls put() with null key");
        }
        if (key.length() == 0) {
            throw new IllegalArgumentException("key must have length >= 1");
        }
        boolean present = containsKey(key);
        if (val == null) {
            if (!present)
                return;   // nothing to delete; do not build a dangling path
            size--;       // deleting an existing key
        } else if (!present) {
            size++;       // genuinely new key
        }
        root = put(root, key, val, 0);
    }

    /**
     * Inserts the pair into the subtree rooted at {@code node}; {@code d} is the index of
     * the character of {@code key} being matched. Returns the (possibly new) subtree root.
     */
    private TernaryTrieNode<T> put(TernaryTrieNode<T> node, String key, T val, int d)
    {
        char c = key.charAt(d);
        if (node == null)
        {
            node = new TernaryTrieNode<>();
            node.ch = c;
        }
        if(c < node.ch)
            node.left = put(node.left, key, val, d);
        else if (c > node.ch)
            node.right = put(node.right, key, val, d);
        else if (d < key.length() - 1)
            node.mid = put(node.mid, key, val, d+1);
        else
            node.val = val;
        return node;
    }

    /**
     * Returns whether {@code key} is associated with a non-null value.
     *
     * @throws IllegalArgumentException if {@code key} is null or empty
     */
    public boolean containsKey(String key)
    {
        return get(key)!=null;
    }

    /**
     * Returns the value associated with {@code key}, or {@code null} if absent.
     *
     * @throws IllegalArgumentException if {@code key} is null or empty
     */
    public T get(String key)
    {
        if (key == null)
            throw new IllegalArgumentException("argument to get() is null");
        // Checked here because the helper returns null before its own check when the trie is empty.
        if (key.length() == 0)
            throw new IllegalArgumentException("key must have length >= 1");
        TernaryTrieNode<T> node=this.get(this.root,key,0);
        return node == null ? null : node.val;
    }

    /**
     * Returns the node at which {@code key} ends in the subtree rooted at {@code node},
     * or {@code null} if no such path exists. {@code d} is the current character index.
     */
    public TernaryTrieNode<T> get(TernaryTrieNode<T> node,String key,int d)
    {
        if(node==null)
            return null;
        if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
        char c=key.charAt(d);
        if(c<node.ch)
            return get(node.left,key,d);
        else if(c>node.ch)
            return get(node.right,key,d);
        else if(d<key.length()-1)
            return get(node.mid,key,d+1);
        else
            return node;
    }

    /** Returns the number of keys in the table. */
    public int size()
    {
        return this.size;
    }

    /** Returns {@code true} iff the table holds no keys. */
    public boolean isEmpty()
    {
        return this.size==0;
    }

    /** Returns all keys in ascending character order. */
    public Iterable<String> keys()
    {
        Queue<String> q=new LinkedList<>();
        collect(this.root,new StringBuilder(),q);
        return q;
    }

    /**
     * In-order traversal adding every key in the subtree rooted at {@code node} to
     * {@code queue}; {@code prefix} holds the characters on the path above {@code node}.
     * {@code prefix} is restored to its original contents before returning.
     */
    private void collect(TernaryTrieNode<T> node, StringBuilder prefix, Queue<String> queue)
    {
        if(node==null)
            return;
        collect(node.left,prefix,queue);
        if(node.val!=null)
            queue.add(prefix.toString()+node.ch);
        collect(node.mid,prefix.append(node.ch),queue);
        prefix.deleteCharAt(prefix.length()-1);   // undo the append (backtrack)
        collect(node.right,prefix,queue);
    }

    /**
     * Returns the longest key in the table that is a prefix of {@code query}.
     * Returns {@code null} if no key is a prefix of it; this includes an empty query.
     *
     * @throws IllegalArgumentException if {@code query} is null
     */
    public String longestPrefixOf(String query)
    {
        if(query==null)
            throw new IllegalArgumentException("calls longestPrefixOf() with null argument");
        int length=longestPrefixOf(this.root,query,0,-1);
        if (length == -1)
            return null;   // sentinel: no stored key is a prefix of query
        return query.substring(0,length);
    }

    /**
     * Returns the length of the longest key that is a prefix of {@code query}, searching the
     * subtree rooted at {@code node} with {@code d} characters already matched along mid
     * links. {@code length} is the best length found so far (-1 if none).
     */
    private int longestPrefixOf(TernaryTrieNode<T> node, String query, int d, int length)
    {
        if(node==null || d == query.length())
            return length;
        char c = query.charAt(d);
        if(c<node.ch)
            return longestPrefixOf(node.left,query,d,length);
        if(c>node.ch)
            return longestPrefixOf(node.right,query,d,length);
        // node.ch matches query[d], so a value here means query[0..d] (length d+1) is a key.
        if(node.val!=null)
            length=d+1;
        return longestPrefixOf(node.mid, query, d+1, length);
    }

    /**
     * Returns all keys that start with {@code prefix}, in ascending order.
     * An empty prefix returns every key.
     *
     * @throws IllegalArgumentException if {@code prefix} is null
     */
    public Iterable<String> keysWithPrefix(String prefix)
    {
        if (prefix == null)
        {
            throw new IllegalArgumentException("calls keysWithPrefix() with null argument");
        }
        if (prefix.length() == 0)
            return keys();   // every key starts with ""
        Queue<String> queue = new LinkedList<>();
        TernaryTrieNode<T> node = get(root, prefix, 0);
        if (node == null)
            return queue;
        if (node.val != null)
            queue.add(prefix);
        collect(node.mid, new StringBuilder(prefix), queue);
        return queue;
    }

    /**
     * Returns all keys matching {@code pattern}, where '.' matches any single character.
     * An empty pattern matches nothing, because empty keys are never stored.
     *
     * @throws IllegalArgumentException if {@code pattern} is null
     */
    public Iterable<String> keysThatMatch(String pattern)
    {
        if (pattern == null)
            throw new IllegalArgumentException("calls keysThatMatch() with null argument");
        Queue<String> q=new LinkedList<>();
        if (pattern.length() == 0)
            return q;
        collect(this.root,new StringBuilder(),0,pattern,q);
        return q;
    }

    /**
     * Adds to {@code queue} every key in the subtree rooted at {@code node} whose remaining
     * characters match {@code pattern} from index {@code i}; {@code prefix} holds the matched path.
     */
    private void collect(TernaryTrieNode<T> node, StringBuilder prefix, int i, String pattern, Queue<String> queue)
    {
        if(node == null) return;
        char c = pattern.charAt(i);
        if (c == '.' || c < node.ch)
            collect(node.left, prefix, i, pattern, queue);
        if (c == '.' || c == node.ch)
        {
            if (i == pattern.length() - 1 && node.val != null)
                queue.add(prefix.toString() + node.ch);
            if (i < pattern.length() - 1)
            {
                collect(node.mid, prefix.append(node.ch), i+1, pattern, queue);
                prefix.deleteCharAt(prefix.length() - 1);   // backtrack
            }
        }
        if (c == '.' || c > node.ch) collect(node.right, prefix, i, pattern, queue);
    }

    public static void main(String[] args)
    {
        TernaryTrieST<String> st=new TernaryTrieST<>();
        String[] strs={"AB","ABBA","ABCD","BCD","ABAB","ABBAB","BADA","BCDA"};
        for(int i=0;i<strs.length;i++)
        {
            if(strs[i]!=null)
            {
                System.out.println(strs[i]);
                st.put(strs[i],strs[i]);
            }
        }
        System.out.println(st.containsKey("ABAB"));
        Iterator<String> iter1=st.keys().iterator();
        while (iter1.hasNext())
            System.out.println(iter1.next());
        System.out.println(st.containsKey("BADC"));
        st.put("ABCC","ABCC");
        st.put("CBAC","CBAC");
        Iterator<String> iter2=st.keysThatMatch("AB").iterator();
        while (iter2.hasNext())
            System.out.println(iter2.next());
        System.out.println(st.containsKey("BADC"));
        Iterator<String> iter3=st.keysWithPrefix("AB").iterator();
        while (iter3.hasNext())
            System.out.println(iter3.next());
        System.out.println("hahha");
    }
}
这份代码参考《算法(第4版)》的源码写成,从中学到了不少代码抽象封装的技巧以及说明文档的书写规范。其中的递归算法十分巧妙(tricky),这里就用到了常用的回溯法(backtracking)。