题目来源:力扣
题目描述:
实现一个 MapSum 类里的两个方法,insert 和 sum。
对于方法 insert,你将得到一对(字符串,整数)的键值对。字符串表示键,整数表示值。如果键已经存在,那么原来的键值对将被替代成新的键值对。
对于方法 sum,你将得到一个表示前缀的字符串,你需要返回所有以该前缀开头的键的值的总和。
=================================================
示例 1:
输入: insert(“apple”, 3), 输出: Null
输入: sum(“ap”), 输出: 3
输入: insert(“app”, 2), 输出: Null
输入: sum(“ap”), 输出: 5
==================================================
审题:
对于本题,使用单词查找树设计算法.因此该问题可以分解为两个子问题(1)插入字符串构建单词查找树.(2)搜索所有以给定字符串为前缀的字符串并累加其值.
此处,我们使用三向查找树结构实现单词查找树.三项单词查找树的插入构建此处不再细讲,只分析在三向查找树结构中如果搜索所有以给定字符串为前缀的字符串.
首先,我们在三向查找树中查找给定前缀字符串最后一个字符对应的树节点,如果未搜索到,则返回null.假设当前节点N为前缀字符串最后一个字符对应的节点,则以该字符串为前缀的所有字符串可能包括:该前缀字符串,该节点mid链接往下搜索查找到的所有字符串.因此沿mid链接搜索所有字符串并累加值即可.
java算法:
//实现三向查找树
class MapSum {
class Node{
char c;
Node left;
Node right;
Node mid;
Integer val;
Node(char c, Integer val){
this.c = c;
this.val = val;
}
}
private Node root;
/** Initialize your data structure here. */
public MapSum() {
}
private Node insert(Node x, String key, int val, int d){
if(x == null){
x = new Node(key.charAt(d), null);
}
if(x.c > key.charAt(d)){
x.left = insert(x.left, key, val, d);
}
else if(x.c < key.charAt(d)){
x.right = insert(x.right, key, val, d);
}
else if(d < key.length() - 1)
x.mid = insert(x.mid, key, val, d+1);
else
x.val = val;
return x;
}
public void insert(String key, int val) {
root = insert(root, key, val, 0);
}
// 搜索prefix对应的终止节点
private Node searchNode(Node x, String prefix, int d){
if(x == null)
return null;
if(x.c > prefix.charAt(d))
return searchNode(x.left, prefix, d);
else if(x.c < prefix.charAt(d))
return searchNode(x.right, prefix, d);
else if(d < prefix.length()-1)
return searchNode(x.mid, prefix, d+1);
else
return x;
}
private int sum(Node x){
if(x == null)
return 0;
int sumVal = 0;
if(x.val != null)
sumVal += x.val;
return sumVal + sum(x.left) + sum(x.mid) + sum(x.right);
}
//以prefix起始的字符串包括prefix(如果prefix存在与前缀树中)以及mid节点之后的所有节点
public int sum(String prefix) {
Node start = searchNode(root, prefix, 0);
if(start == null)
return 0;
int sumVal = 0;
if(start.val != null)
sumVal += start.val;
return sumVal + sum(start.mid);
}
}