废话不多说，直接上代码。依次是结点类 Node、二叉树接口 BinTree 以及二叉查找树实现 BST：
/**
 * Node of a binary search tree keyed by {@link Integer}.
 *
 * <p>Besides its key/value and child links, each node keeps a pointer to its parent (the root's
 * parent is {@code null}) and a cached subtree size that the owning tree is responsible for
 * keeping up to date.
 *
 * @param <V> value type
 */
public class Node<V> {

    /** Key of this node. */
    private Integer key;

    /** Value stored under {@link #key}. */
    private V value;

    /** Left child, or {@code null}. */
    private Node<V> left;

    /** Right child, or {@code null}. */
    private Node<V> right;

    /** Parent node; {@code null} for the root. */
    private Node<V> parentNode;

    /**
     * Cached size of the subtree rooted at this node (this node plus its descendants). Node itself
     * never recomputes it; the owning tree must maintain it via {@link #setSize(int)}.
     */
    private int size;

    /**
     * Creates a node with no children.
     *
     * @param key        node key
     * @param value      node value
     * @param parentNode parent node, or {@code null} for a root
     * @param size       initial subtree size (1 for a freshly inserted leaf)
     */
    public Node(Integer key, V value, Node<V> parentNode, int size) {
        this.key = key;
        this.value = value;
        this.parentNode = parentNode;
        this.size = size;
    }

    public Integer getKey() {
        return key;
    }

    public void setKey(Integer key) {
        this.key = key;
    }

    public V getValue() {
        return value;
    }

    public void setValue(V value) {
        this.value = value;
    }

    public Node<V> getLeft() {
        return left;
    }

    public void setLeft(Node<V> left) {
        this.left = left;
    }

    public Node<V> getRight() {
        return right;
    }

    public void setRight(Node<V> right) {
        this.right = right;
    }

    public Node<V> getParentNode() {
        return parentNode;
    }

    public void setParentNode(Node<V> parentNode) {
        this.parentNode = parentNode;
    }

    public int getSize() {
        return size;
    }

    public void setSize(int size) {
        this.size = size;
    }
}
public interface BinTree<V> { /** * 添加一个结点 * @param key * @param value */ void put(Integer key, V value); /** * 获取一个结点的值 * @param key * @return */ V get(Integer key); /** * 删除一个结点 * @param key * @return */ Node<V> remove(Integer key); /** * 统计根节点下的结点数 * @return */ int size(); /** * 统计指定节点下的结点数 * @return */ int size(Node node); /** * 二叉树的前序遍历 * @return */ List<Integer> beforeForeach(); /** * 二叉树的中序遍历 * @return */ List<Integer> centreForeach(); /** * 二叉树的后序遍历 * @return */ List<Integer> laterForeach();
public class BST<V> implements BinTree<V> { private Node<V> root = null; @Override public void put(Integer key, V value) { if(key == null){ throw new IllegalArgumentException("key is null"); } root = put(root,root,key,value); } @Override public V get(Integer key) { if(key == null){ throw new IllegalArgumentException("key is null"); } Node<V> node = get(root, key); if(node == null){ return null; } return node.getValue(); } /** * * 删除时这里需要分三种情况 * 1.删除的结点无子结点 * 2.删除的结点有1个子结点 左子树 / 右子树 * 3.删除的结点有2个子结点 * * @param key * @return */ @Override public Node<V> remove(Integer key) { if(key == null){ throw new IllegalArgumentException("key is null"); } Node<V> node = get(root, key); if(node == null){ return null; } delete(node); return node; } @Override public int size() { if(root == null){ return 0; } return root.getSize(); } @Override public int size(Node node) { if(node == null){ return 0; } return node.getSize(); } @Override public List<Integer> beforeForeach() { return beforeForeach(root); } @Override public List<Integer> centreForeach() { return centreForeach(root); } @Override public List<Integer> laterForeach() { return laterForeach(root); } /** * search node * @param node * @param key * @return */ private Node<V> get(Node<V> node,Integer key){ // search miss if(node == null){ return null; } if(key == null){ throw new IllegalArgumentException("key is null"); } if(node.getKey() > key){ return get(node.getLeft(),key); }else if(node.getKey() < key){ return get(node.getRight(),key); }else{ return node; } } /** * add source node * @param parentNode * @param node * @param key * @return */ private Node<V> put(Node<V> parentNode,Node<V> node,Integer key,V value){ if(key == null){ throw new IllegalArgumentException("key is null"); } // node not exist,init node now if(node == null){ return new Node<>(key,value,parentNode,1); } if(node.getKey() > key){ Node<V> leftNode = put(node,node.getLeft(), key, value); node.setLeft(leftNode); }else if(node.getKey() < key){ Node<V> rightNode 
= put(node,node.getRight(), key, value); node.setRight(rightNode); }else{ node.setValue(value); } node.setSize(size(node.getLeft())+size(node.getRight())+1); return node; } /** * 删除结点 * @param node */ private void delete(Node<V> node){ //有2个子结点时 if(node.getLeft() != null && node.getRight() != null) { //获取后继结点 Node<V> successorNode = getSuccessorNode(node); reduceNodeSize(successorNode); //当后继结点的父结点为删除的结点时 if(Objects.equals(successorNode.getParentNode().getKey(),node.getKey())){ node.setRight(successorNode.getRight()); }else{ successorNode.getParentNode().setLeft(successorNode.getRight()); if(successorNode.getRight() != null){ successorNode.getRight().setParentNode(successorNode.getParentNode()); } } node.setKey(successorNode.getKey()); node.setValue(successorNode.getValue()); return; } Node<V> temp = null; //只存在一个子结点时 boolean oneNode = (node.getLeft() != null && node.getRight() == null) || (node.getRight() != null && node.getLeft() == null); if(oneNode){ temp = node.getLeft() != null ? node.getLeft() : node.getRight(); temp.setParentNode(node.getParentNode()); if(node.getParentNode() == null){ root = temp; return; } if(node.getKey() > node.getParentNode().getKey()){ node.getParentNode().setRight(temp); } if(node.getKey() < node.getParentNode().getKey()){ node.getParentNode().setLeft(temp); } reduceNodeSize(temp); } //无子结点时 if(node.getLeft() == null && node.getRight() == null){ if(node.getParentNode() == null){ destroyTree(); return; } if(node.getKey() > node.getParentNode().getKey()){ node.getParentNode().setRight(null); } if(node.getKey() < node.getParentNode().getKey()){ node.getParentNode().setLeft(null); } reduceNodeSize(node); } } /** * search successor node * @param node this node * @return successor node */ private Node<V> getSuccessorNode(Node<V> node){ if(node == null){ return node; } Node<V> successorNode = node.getRight(); if(successorNode == null){ return node; } while (successorNode.getLeft() != null){ successorNode = successorNode.getLeft(); } return 
successorNode; } /** * search min node * @param node this node * @return min node */ private Node<V> getMinNode(Node<V> node){ if(node == null){ return node; } Node<V> minNode = node.getLeft(); if(minNode == null){ return node; } while (minNode.getLeft() != null){ minNode = minNode.getLeft(); } return minNode; } /** * search max node * @param node this node * @return max node */ private Node<V> getMaxNode(Node<V> node){ if(node == null){ return node; } Node<V> maxNode = node.getRight(); if(maxNode == null){ return node; } while (maxNode.getRight() != null){ maxNode = maxNode.getRight(); } return maxNode; } /** * when remove node,need to reduce node size * @param node this node */ private void reduceNodeSize(Node<V> node){ if(node == null || node.getParentNode() == null){ return; } node.getParentNode().setSize(node.getParentNode().getSize()-1); reduceNodeSize(node.getParentNode()); } /** * destroy this tree */ private void destroyTree(){ root = null; } /** * this tree with before foreach * @param node this node * @return */ private List<Integer> beforeForeach(Node<V> node){ List<Integer> result = new ArrayList<>(); if(node == null){ return result; } result.add(node.getKey()); result.addAll(beforeForeach(node.getLeft())); result.addAll(beforeForeach(node.getRight())); return result; } /** * this tree with centre foreach * @param node this node * @return */ private List<Integer> centreForeach(Node<V> node){ List<Integer> result = new ArrayList<>(); if(node == null){ return result; } result.addAll(centreForeach(node.getLeft())); result.add(node.getKey()); result.addAll(centreForeach(node.getRight())); return result; } /** * this tree with later foreach * @param node this node * @return */ private List<Integer> laterForeach(Node<V> node){ List<Integer> result = new ArrayList<>(); if(node == null){ return result; } result.addAll(laterForeach(node.getLeft())); result.addAll(laterForeach(node.getRight())); result.add(node.getKey()); return result; } }
最后是测试类。将以上各类分别保存为同名的 .java 文件，并补充所需的 import（java.util.List、java.util.ArrayList、java.util.Objects）后即可直接运行：
/**
 * Demo entry point: builds a small binary search tree and prints its keys in in-order
 * (ascending) sequence, each followed by a comma, on a single line.
 */
public class Test {
    public static void main(String[] args) {
        BinTree<String> tree = new BST<>();
        // Insertion order matters: it determines the tree shape (the value is the key as text).
        int[] keys = {50, 25, 24, 26, 59, 57, 58};
        for (int k : keys) {
            tree.put(k, String.valueOf(k));
        }

        StringBuilder line = new StringBuilder("中序遍历:");
        for (Integer k : tree.centreForeach()) {
            line.append(k).append(',');
        }
        System.out.println(line);
    }
}