此程序对原实现进行了泛化，并修改了 findMax 和 findMin 两个方法，使其使用起来方便了许多。详情请参见原作者的文章（原文链接未能保留）。
附上原代码：
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
/**
 * A generic top-down splay tree mapping keys of type {@code K} to values of type {@code V}.
 *
 * <p>Notes:
 * <ol>
 * <li>The lookup methods ({@code find}, {@code findMax}, {@code findMin}) are mutating
 * operations: each one splays the tree and therefore changes its root.</li>
 * <li>No operation is thread-safe, because every operation, including lookups, may
 * restructure the tree. Share an instance across threads only with external synchronization.</li>
 * </ol>
 *
 * @param <K> key type, ordered by the comparator supplied at construction
 * @param <V> value type
 * @author clj 2014-03-31
 */
public class SplayTreeTemplate<K, V> {

    /**
     * A tree node holding one key/value pair and links to its two children.
     * The fields are package-private so the demos below can build trees by hand.
     */
    static class SplayNode<K, V> {
        K key;
        V val;
        SplayNode<K, V> left;
        SplayNode<K, V> right;

        public SplayNode(K key, V val) {
            this.key = key;
            this.val = val;
        }

        /** Returns the string form of this node's value (not its key). */
        @Override
        public String toString() {
            return String.valueOf(val);
        }
    }

    /** Orders the keys; every lookup, insertion and removal goes through it. */
    private final Comparator<? super K> comparator;
    /** Current root; reassigned by every splaying operation. */
    private SplayNode<K, V> root;
    /** Number of nodes, maintained incrementally by insert/remove/delete*. */
    private int count = 0;

    /**
     * Wraps an existing tree.
     *
     * @param root root of a tree already ordered consistently with {@code c}; may be
     *             {@code null}. The tree is neither copied nor validated.
     * @param c    comparator used to order keys
     */
    public SplayTreeTemplate(SplayNode<K, V> root, Comparator<? super K> c) {
        this.root = root;
        this.comparator = c;
        this.count = countOfNodes(root);
    }

    /**
     * Creates an empty tree.
     *
     * @param c comparator used to order keys
     */
    public SplayTreeTemplate(Comparator<? super K> c) {
        this.comparator = c;
    }

    /**
     * Counts the nodes of a subtree.
     * Iterative (explicit stack) so that deep, degenerate trees - which splay trees can
     * legitimately be - do not overflow the call stack.
     */
    private int countOfNodes(SplayNode<K, V> subtree) {
        int nodes = 0;
        LinkedList<SplayNode<K, V>> pending = new LinkedList<SplayNode<K, V>>();
        if (subtree != null) {
            pending.push(subtree);
        }
        while (!pending.isEmpty()) {
            SplayNode<K, V> node = pending.pop();
            nodes++;
            if (node.left != null) {
                pending.push(node.left);
            }
            if (node.right != null) {
                pending.push(node.right);
            }
        }
        return nodes;
    }

    /** Returns the current root node (used by the display utilities); {@code null} when empty. */
    public SplayNode<K, V> getRoot() {
        return root;
    }

    /**
     * Splays {@code key} - or, if it is absent, the last node on its search path - to the root
     * and reports whether the key is present.
     *
     * @param key key to look up
     * @return {@code true} if a node with an equal key exists
     * @throws Exception if the tree is empty
     */
    public boolean find(K key) throws Exception {
        if (root == null) {
            throw new Exception("tree is empty.");
        }
        splay(key);
        return comparator.compare(root.key, key) == 0;
    }

    /**
     * Top-down splays the maximum node of the subtree rooted at {@code rootNode} to the top
     * of that subtree.
     *
     * @param rootNode root of the subtree to search
     * @return the new subtree root; it holds the maximum key and has no right child
     * @throws Exception if {@code rootNode} is {@code null}
     */
    private SplayNode<K, V> findMax(SplayNode<K, V> rootNode) throws Exception {
        if (rootNode == null) {
            throw new Exception("tree or subtree is empty.");
        }
        // header node: its right link collects the root of the "left tree" (keys below the max)
        SplayNode<K, V> pseudoNode = new SplayNode<K, V>(null, null);
        // largest node of the left tree so far; the next piece is attached to its right link
        SplayNode<K, V> leftMax = pseudoNode;
        SplayNode<K, V> t = rootNode;
        while (true) {
            SplayNode<K, V> parent = t.right;
            if (parent == null) {
                break; // t has no right child, so t is the maximum
            }
            // parent is the target's parent and t the target's grandparent
            if (parent.right == null) {
                // zag: parent is the maximum; move t (with its left subtree) into the left tree
                t.right = null;
                leftMax.right = t;
                leftMax = t;
                t = parent;
                break;
            }
            // zag-zag: rotate parent above t, then move parent (now holding t) into the left tree
            SplayNode<K, V> next = parent.right;
            rotateRightChild(t, parent); // leaves parent.right == null
            leftMax.right = parent;
            leftMax = parent;
            t = next;
        }
        // re-assemble: t's left subtree goes under the left tree's max, the left tree becomes t.left
        leftMax.right = t.left;
        t.left = pseudoNode.right;
        return t;
    }

    /**
     * Splays the maximum key to the root and returns its value.
     *
     * @throws Exception if the tree is empty
     */
    public V findMax() throws Exception {
        root = findMax(this.root);
        return root.val;
    }

    /**
     * Top-down splays the minimum node of the subtree rooted at {@code rootNode} to the top
     * of that subtree.
     *
     * @param rootNode root of the subtree to search
     * @return the new subtree root; it holds the minimum key and has no left child
     * @throws Exception if {@code rootNode} is {@code null}
     */
    private SplayNode<K, V> findMin(SplayNode<K, V> rootNode) throws Exception {
        if (rootNode == null) {
            throw new Exception("tree or subtree is empty.");
        }
        // header node: its left link collects the root of the "right tree" (keys above the min)
        SplayNode<K, V> pseudoNode = new SplayNode<K, V>(null, null);
        // smallest node of the right tree so far; the next piece is attached to its left link
        SplayNode<K, V> rightMin = pseudoNode;
        SplayNode<K, V> t = rootNode;
        while (true) {
            SplayNode<K, V> parent = t.left;
            if (parent == null) {
                break; // t has no left child, so t is the minimum
            }
            // parent is the target's parent and t the target's grandparent
            if (parent.left == null) {
                // zig: parent is the minimum; move t (with its right subtree) into the right tree
                t.left = null;
                rightMin.left = t;
                rightMin = t;
                t = parent;
                break;
            }
            // zig-zig: rotate parent above t, then move parent (now holding t) into the right tree
            SplayNode<K, V> next = parent.left;
            rotateLeftChild(t, parent); // leaves parent.left == null
            rightMin.left = parent;
            rightMin = parent;
            t = next;
        }
        // re-assemble: t's right subtree goes under the right tree's min, the right tree becomes t.right
        rightMin.left = t.right;
        t.right = pseudoNode.left;
        return t;
    }

    /**
     * Splays the minimum key to the root and returns its value.
     *
     * @throws Exception if the tree is empty
     */
    public V findMin() throws Exception {
        root = findMin(this.root);
        return root.val;
    }

    /**
     * Removes the node with the largest key and returns its value.
     *
     * @throws Exception if the tree is empty
     */
    public V deleteMax() throws Exception {
        V max = findMax();
        // the maximum is now the root and has no right child
        this.root = root.left;
        count--;
        return max;
    }

    /**
     * Removes the node with the smallest key and returns its value.
     *
     * @throws Exception if the tree is empty
     */
    public V deleteMin() throws Exception {
        V min = findMin();
        // the minimum is now the root and has no left child
        this.root = root.right;
        count--;
        return min;
    }

    /**
     * Inserts {@code key -> val}; if the key is already present its value is replaced.
     * The inserted (or updated) node ends up at the root.
     *
     * @throws Exception declared for source compatibility; this implementation throws none
     */
    public void insert(K key, V val) throws Exception {
        if (root == null) {
            root = new SplayNode<K, V>(key, val);
            count++;
            return;
        }
        splay(key);
        int cmp = comparator.compare(key, root.key);
        if (cmp == 0) {
            // key already present: replace the value, size unchanged
            root.val = val;
            return;
        }
        SplayNode<K, V> node = new SplayNode<K, V>(key, val);
        if (cmp < 0) {
            // root is key's successor, so every key in root.left is smaller than key:
            // the new node takes over root.left and adopts root as its right child
            node.left = root.left;
            node.right = root;
            root.left = null;
        } else {
            // root is key's predecessor: the new node takes over root.right, root goes left
            node.left = root;
            node.right = root.right;
            root.right = null;
        }
        root = node;
        count++;
    }

    /**
     * Removes the node with {@code key} and returns its value.
     *
     * @throws Exception if the tree is empty or the key is not present
     */
    public V remove(K key) throws Exception {
        if (root == null) {
            throw new Exception("tree is empty.");
        }
        splay(key);
        if (comparator.compare(root.key, key) != 0) {
            throw new Exception("value not found.");
        }
        SplayNode<K, V> removed = root;
        if (root.left == null) {
            // the removed key is the minimum: its right subtree is the whole remaining tree
            root = root.right;
        } else {
            // splay the max of the left subtree to its top (it then has no right child)
            // and hang the removed node's right subtree there
            SplayNode<K, V> newRoot = findMax(root.left);
            newRoot.right = root.right;
            root = newRoot;
        }
        count--;
        return removed.val;
    }

    /**
     * Rotates {@code parent} (the left child of {@code grandparent}) above it, then detaches
     * parent's left subtree, which the caller continues splaying.
     */
    private void rotateLeftChild(SplayNode<K, V> grandparent, SplayNode<K, V> parent) {
        grandparent.left = parent.right;
        parent.right = grandparent;
        parent.left = null;
    }

    /**
     * Rotates {@code parent} (the right child of {@code grandparent}) above it, then detaches
     * parent's right subtree, which the caller continues splaying.
     */
    private void rotateRightChild(SplayNode<K, V> grandparent, SplayNode<K, V> parent) {
        grandparent.right = parent.left;
        parent.left = grandparent;
        parent.right = null;
    }

    /**
     * Splays the node with {@code key} - or, if absent, the last node on its search path -
     * to the root. Does nothing on an empty tree.
     *
     * @param key target key
     */
    public void splay(K key) {
        this.root = splay(this.root, key);
    }

    /**
     * Top-down splay of the subtree rooted at {@code rootNode} for {@code key}.
     *
     * @param rootNode root of the tree or subtree to splay; may be {@code null}
     * @param key      target key
     * @return the new root of the splayed tree or subtree ({@code null} if it was empty)
     */
    private SplayNode<K, V> splay(SplayNode<K, V> rootNode, K key) {
        if (rootNode == null) {
            return null;
        }
        // one header node serves both side trees: its right link roots the left tree
        // (keys < key) and its left link roots the right tree (keys > key)
        SplayNode<K, V> pseudoNode = new SplayNode<K, V>(null, null);
        SplayNode<K, V> leftMax = pseudoNode;  // largest node of the left tree so far
        SplayNode<K, V> rightMin = pseudoNode; // smallest node of the right tree so far
        SplayNode<K, V> t = rootNode;
        while (true) {
            int comp = comparator.compare(key, t.key);
            if (comp == 0) {
                break;
            } else if (comp < 0) {
                // parent is the target's parent, t the target's grandparent
                SplayNode<K, V> parent = t.left;
                if (parent == null) {
                    break;
                }
                if (comparator.compare(key, parent.key) < 0) {
                    if (parent.left == null) {
                        // zig: parent is the last node on the search path
                        t.left = null;
                        rightMin.left = t;
                        rightMin = t;
                        t = parent;
                        break;
                    }
                    // zig-zig
                    SplayNode<K, V> next = parent.left;
                    rotateLeftChild(t, parent); // leaves parent.left == null
                    rightMin.left = parent;
                    rightMin = parent;
                    t = next;
                } else {
                    // key >= parent.key: zig, or zig-zag simplified to zig
                    t.left = null;
                    rightMin.left = t;
                    rightMin = t;
                    t = parent;
                }
            } else {
                SplayNode<K, V> parent = t.right;
                if (parent == null) {
                    break;
                }
                if (comparator.compare(key, parent.key) > 0) {
                    if (parent.right == null) {
                        // zag: parent is the last node on the search path
                        t.right = null;
                        leftMax.right = t;
                        leftMax = t;
                        t = parent;
                        break;
                    }
                    // zag-zag
                    SplayNode<K, V> next = parent.right;
                    rotateRightChild(t, parent); // leaves parent.right == null
                    leftMax.right = parent;
                    leftMax = parent;
                    t = next;
                } else {
                    // key <= parent.key: zag, or zag-zig simplified to zag
                    t.right = null;
                    leftMax.right = t;
                    leftMax = t;
                    t = parent;
                }
            }
        }
        // re-assemble (also correct when the loop made no step)
        leftMax.right = t.left;
        rightMin.left = t.right;
        t.left = pseudoNode.right;
        t.right = pseudoNode.left;
        return t;
    }

    /** Returns the number of nodes in the tree. */
    public int getSize() {
        return count;
    }

    /** Returns {@code true} when the tree holds no nodes. */
    public boolean isEmpty() {
        return count == 0;
    }

    /**
     * Test utility: prints the values of the tree in key order, each preceded by a space.
     * Recursion depth equals the tree height, so use it on small trees only.
     */
    public static void recursiveInOrderTraverse(SplayNode<?, ?> root) {
        if (root == null) {
            return;
        }
        recursiveInOrderTraverse(root.left);
        System.out.format(" %s", root.val);
        recursiveInOrderTraverse(root.right);
    }

    /**
     * Test utility: prints the keys level by level, each key right-aligned in a two-character
     * cell at the column it would occupy in a complete binary tree of the same height.
     * The grid has 2^(height+1) - 1 columns, so this is only practical for shallow trees.
     *
     * @param root root of the tree to print; nothing is printed if {@code null}
     * @param n    unused (kept for compatibility)
     */
    public static void displayBinaryTree(SplayNode<?, ?> root, int n) {
        if (root == null) {
            return;
        }
        LinkedList<SplayNode<?, ?>> queue = new LinkedList<SplayNode<?, ?>>();
        // all nodes of each level
        List<List<SplayNode<?, ?>>> nodesList = new ArrayList<List<SplayNode<?, ?>>>();
        // for each level, each node's position among the slots of a complete tree at that level
        List<List<Integer>> nextPosList = new ArrayList<List<Integer>>();
        queue.add(root);
        int levelNodes = 1;     // nodes of the current level not yet dequeued
        int nextLevelNodes = 0; // nodes of the next level enqueued so far
        List<SplayNode<?, ?>> levelNodesList = new ArrayList<SplayNode<?, ?>>();
        List<Integer> nextLevelNodesPosList = new ArrayList<Integer>();
        List<Integer> levelNodesPosList = new ArrayList<Integer>();
        levelNodesPosList.add(0); // root position
        nextPosList.add(levelNodesPosList);
        int levelNodesTotal = 1;
        while (!queue.isEmpty()) {
            SplayNode<?, ?> node = queue.remove();
            if (levelNodes == 0) {
                // the previous level is complete; switch to the next one
                nodesList.add(levelNodesList);
                nextPosList.add(nextLevelNodesPosList);
                levelNodesPosList = nextLevelNodesPosList;
                levelNodesList = new ArrayList<SplayNode<?, ?>>();
                nextLevelNodesPosList = new ArrayList<Integer>();
                levelNodes = nextLevelNodes;
                levelNodesTotal = nextLevelNodes;
                nextLevelNodes = 0;
            }
            levelNodesList.add(node);
            int pos = levelNodesPosList.get(levelNodesTotal - levelNodes);
            if (node.left != null) {
                queue.add(node.left);
                nextLevelNodes++;
                nextLevelNodesPosList.add(2 * pos);
            }
            if (node.right != null) {
                queue.add(node.right);
                nextLevelNodes++;
                nextLevelNodesPosList.add(2 * pos + 1);
            }
            levelNodes--;
        }
        // save the last level's nodes
        nodesList.add(levelNodesList);
        int maxLevel = nodesList.size() - 1;
        // columns of a complete tree of this height: 2^(maxLevel+1) - 1
        int cols = 1;
        for (int i = 0; i <= maxLevel; i++) {
            cols <<= 1;
        }
        cols--;
        SplayNode<?, ?>[][] tree = new SplayNode<?, ?>[maxLevel + 1][cols];
        for (int currLevel = 0; currLevel <= maxLevel; currLevel++) {
            levelNodesList = nodesList.get(currLevel);
            levelNodesPosList = nextPosList.get(currLevel);
            // column of this level's j-th slot: 2^(maxLevel-level) * (2j+1) - 1
            int coeff = 1;
            for (int i = 0; i < maxLevel - currLevel; i++) {
                coeff <<= 1;
            }
            for (int k = 0; k < levelNodesList.size(); k++) {
                int j = levelNodesPosList.get(k);
                int col = coeff * (2 * j + 1) - 1;
                tree[currLevel][col] = levelNodesList.get(k);
            }
        }
        System.out.format("%n");
        for (int i = 0; i <= maxLevel; i++) {
            for (int j = 0; j < cols; j++) {
                SplayNode<?, ?> node = tree[i][j];
                if (node == null) {
                    // an empty cell must be as wide as an occupied one ("%2s") to keep columns aligned
                    System.out.format("  ");
                } else {
                    // %s rather than %d: keys are generic, and %2s prints integers identically to %2d
                    System.out.format("%2s", node.key);
                }
            }
            System.out.format("%n");
        }
    }

    /** Test utility: prints the tree's in-order value listing followed by its level-by-level layout. */
    public static void printAfterSplayed(SplayTreeTemplate<?, ?> splayTree) {
        SplayNode<?, ?> root = splayTree.getRoot();
        System.out.format("%nAfter being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        System.out.format("%n%n%nAfter being splayed, the tree is:");
        SplayTreeTemplate.displayBinaryTree(root, splayTree.getSize());
    }

    /**
     * Orders {@code Integer}s by numeric value.
     * Uses {@code compareTo}: comparing boxed Integers with {@code ==} tests reference identity,
     * which is false for equal values outside the Integer cache (-128..127).
     */
    static Comparator<Integer> integerComparator() {
        return new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                return o1.compareTo(o2);
            }
        };
    }

    public static void main(String[] args) throws Exception {
        Comparator<Integer> comparator = integerComparator();
        // the smaller, printed demonstrations can be enabled here
        // test1(comparator);
        // test2(comparator);
        // test3(comparator);
        test4(comparator);
    }

    /** Timing demo: one million inserts, one million lookups, then drain via deleteMax. */
    private static void test4(Comparator<Integer> comparator) throws Exception {
        SplayTreeTemplate<Integer, String> splayTree;
        System.out.format("************************************");
        System.out.format("%nTest case 4 - priority queue:%n");
        splayTree = new SplayTreeTemplate<Integer, String>(comparator);
        final int operations = 1000000;
        // a single generator: allocating a new Random per key only adds noise to the timings
        Random random = new Random();
        long current = System.currentTimeMillis();
        for (int i = 0; i < operations; i++) {
            Integer key = i + random.nextInt(10);
            String val = "v_" + key;
            splayTree.insert(key, val);
        }
        long duration = System.currentTimeMillis() - current;
        System.out.println("完成插入: " + duration);
        current = System.currentTimeMillis();
        for (int i = 0; i < operations; i++) {
            Integer key = i + random.nextInt(10);
            splayTree.find(key);
        }
        duration = System.currentTimeMillis() - current;
        System.out.println("完成查询:" + duration);
        current = System.currentTimeMillis();
        while (!splayTree.isEmpty()) {
            splayTree.deleteMax();
        }
        duration = System.currentTimeMillis() - current;
        System.out.println("完成删除:" + duration);
    }

    /** Demo: use the tree as a priority queue over a right-leaning chain of ten nodes. */
    private static void test3(Comparator<Integer> comparator) throws Exception {
        SplayNode<Integer, String> root;
        SplayTreeTemplate<Integer, String> splayTree;
        String max;
        Integer newKey;
        String newVal;
        System.out.format("************************************");
        System.out.format("%nTest case 3 - priority queue:%n");
        SplayNode<Integer, String> m1 = new SplayNode<Integer, String>(1, "1");
        SplayNode<Integer, String> m4 = new SplayNode<Integer, String>(4, "4");
        SplayNode<Integer, String> m7 = new SplayNode<Integer, String>(7, "7");
        SplayNode<Integer, String> m9 = new SplayNode<Integer, String>(9, "9");
        SplayNode<Integer, String> m20 = new SplayNode<Integer, String>(20, "20");
        SplayNode<Integer, String> m22 = new SplayNode<Integer, String>(22, "22");
        SplayNode<Integer, String> m26 = new SplayNode<Integer, String>(26, "26");
        SplayNode<Integer, String> m29 = new SplayNode<Integer, String>(29, "29");
        SplayNode<Integer, String> m30 = new SplayNode<Integer, String>(30, "30");
        SplayNode<Integer, String> m36 = new SplayNode<Integer, String>(36, "36");
        m1.right = m4;
        m4.right = m7;
        m7.right = m9;
        m9.right = m20;
        m20.right = m22;
        m22.right = m26;
        m26.right = m29;
        m29.right = m30;
        m30.right = m36;
        root = m1;
        System.out.format("%nBefore being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        splayTree = new SplayTreeTemplate<Integer, String>(root, comparator);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        newKey = 16;
        newVal = "16";
        splayTree.insert(newKey, newVal);
        System.out.format("%n*****insert new value %d*****%n", newKey);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        newKey = 12;
        newVal = "12";
        splayTree.insert(newKey, newVal);
        System.out.format("%n*****insert new value %d*****%n", newKey);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
    }

    /** Demo: find / findMax / findMin / deleteMax / deleteMin / insert / remove on a small tree. */
    private static void test2(Comparator<Integer> comparator) throws Exception {
        SplayNode<Integer, String> root;
        SplayTreeTemplate<Integer, String> splayTree;
        System.out.format("************************************");
        System.out.format("%nTest case 2 - splaytree's operations:%n");
        /*
         *         13
         *       /    \
         *     10      25
         *       \    /  \
         *       12  20   35
         *               /
         *              29
         */
        SplayNode<Integer, String> n13 = new SplayNode<Integer, String>(13, "13");
        SplayNode<Integer, String> n10 = new SplayNode<Integer, String>(10, "10");
        SplayNode<Integer, String> n25 = new SplayNode<Integer, String>(25, "25");
        SplayNode<Integer, String> n12 = new SplayNode<Integer, String>(12, "12");
        SplayNode<Integer, String> n20 = new SplayNode<Integer, String>(20, "20");
        SplayNode<Integer, String> n35 = new SplayNode<Integer, String>(35, "35");
        SplayNode<Integer, String> n29 = new SplayNode<Integer, String>(29, "29");
        n13.left = n10;
        n13.right = n25;
        n10.right = n12;
        n25.left = n20;
        n25.right = n35;
        n35.left = n29;
        root = n13;
        System.out.format("%nBefore being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        splayTree = new SplayTreeTemplate<Integer, String>(root, comparator);
        int val = 25;
        boolean found = splayTree.find(val);
        System.out.format("%n*****%d is in the tree? [%s]*****%n", val, found);
        printAfterSplayed(splayTree);
        String max = splayTree.findMax();
        System.out.format("%n*****max value=%s*****%n", max);
        printAfterSplayed(splayTree);
        String min = splayTree.findMin();
        System.out.format("%n*****min value=%s*****%n", min);
        printAfterSplayed(splayTree);
        max = splayTree.deleteMax();
        System.out.format("%n*****deleted max value: %s*****%n", max);
        printAfterSplayed(splayTree);
        min = splayTree.deleteMin();
        System.out.format("%n*****deleted min value: %s*****%n", min);
        printAfterSplayed(splayTree);
        Integer newKey = 24;
        String newVal = "24";
        splayTree.insert(newKey, newVal);
        System.out.format("%n*****insert new value %d*****%n", newKey);
        printAfterSplayed(splayTree);
        Integer removeVal = 12;
        splayTree.remove(removeVal);
        System.out.format("%n*****remove value %d*****%n", removeVal);
        printAfterSplayed(splayTree);
    }

    /** Demo: splay an absent key (19) in a ten-node tree and print the result. */
    private static void test1(Comparator<Integer> comparator) {
        System.out.format("%nTest case 1 - splay opeartion:%n");
        SplayNode<Integer, String> nn12 = new SplayNode<Integer, String>(12, "12");
        SplayNode<Integer, String> nn5 = new SplayNode<Integer, String>(5, "5");
        SplayNode<Integer, String> nn25 = new SplayNode<Integer, String>(25, "25");
        SplayNode<Integer, String> nn20 = new SplayNode<Integer, String>(20, "20");
        SplayNode<Integer, String> nn30 = new SplayNode<Integer, String>(30, "30");
        SplayNode<Integer, String> nn15 = new SplayNode<Integer, String>(15, "15");
        SplayNode<Integer, String> nn24 = new SplayNode<Integer, String>(24, "24");
        SplayNode<Integer, String> nn13 = new SplayNode<Integer, String>(13, "13");
        SplayNode<Integer, String> nn18 = new SplayNode<Integer, String>(18, "18");
        SplayNode<Integer, String> nn16 = new SplayNode<Integer, String>(16, "16");
        nn12.left = nn5;
        nn12.right = nn25;
        nn25.left = nn20;
        nn25.right = nn30;
        nn20.left = nn15;
        nn20.right = nn24;
        nn15.left = nn13;
        nn15.right = nn18;
        nn18.left = nn16;
        SplayNode<Integer, String> root = nn12;
        System.out.format("%nBefore being splayed, in-order BST:%n");
        SplayTreeTemplate.recursiveInOrderTraverse(root);
        SplayTreeTemplate<Integer, String> splayTree = new SplayTreeTemplate<Integer, String>(
                root, comparator);
        splayTree.splay(19);
        System.out.format("%n*****splay the node with value=19*****%n");
        printAfterSplayed(splayTree);
    }
}