节点定义:
import java.util.List;
/**
 * A node of a general (n-ary) tree.
 *
 * <p>All fields are public and mutable. Neither constructor touches
 * {@code children}, so it stays {@code null} until some caller assigns a list.
 *
 * @param <E> type of the payload stored in the node
 */
public class TreeNode<E> {
    /** Payload carried by this node. */
    public E key;

    /** Parent of this node; {@code null} when none has been set. */
    public TreeNode<E> parent;

    /** Child nodes; {@code null} until a list is assigned. */
    public List<TreeNode<E>> children;

    /**
     * Creates a node holding {@code key} with no parent.
     *
     * @param key payload to store
     */
    public TreeNode(E key) {
        this(key, null);
    }

    /**
     * Creates a node holding {@code key} whose parent reference is {@code parent}.
     * Only this node's {@code parent} field is set; the parent's child list is
     * not modified.
     *
     * @param key    payload to store
     * @param parent value for the {@code parent} field, may be {@code null}
     */
    public TreeNode(E key, TreeNode<E> parent) {
        this.parent = parent;
        this.key = key;
    }

    /** Renders the node as {@code Node[key=<key>]}. */
    @Override
    public String toString() {
        return "Node[key=" + key + "]";
    }
}
接口定义:
/**
 * Operations of a general (n-ary) tree built from {@code TreeNode} instances.
 *
 * @param <E> type of the value stored in each node
 */
public interface ITree<E> {
/** Returns the number of nodes. */
int getSize();
/** Returns the root node. */
TreeNode<E> getRoot();
/** Returns the parent node of {@code x}. */
TreeNode<E> getParent(TreeNode<E> x);
/** Returns the first child of {@code x}. */
TreeNode<E> getFirstChild(TreeNode<E> x);
/** Returns the next sibling of {@code x}. */
TreeNode<E> getNextSibling(TreeNode<E> x);
/** Returns the height of the subtree at {@code x}. */
int getHeight(TreeNode<E> x);
/** Inserts {@code child} as a child node of {@code x}. */
void insertChild(TreeNode<E> x, TreeNode<E> child);
/** Deletes the {@code i}-th child node of {@code x}. */
void deleteChild(TreeNode<E> x, int i);
}
实现类:
import java.util.ArrayList;
import java.util.List;
/**
 * {@link ITree} implementation that stores each node's children in a
 * {@link List} held by the node itself.
 *
 * <p>A node counts as a leaf when its {@code children} list is either
 * {@code null} or empty; both states are treated identically everywhere.
 *
 * @param <E> type of the value stored in each node
 */
public class MyTree<E> implements ITree<E> {
    /** Number of nodes accounted for by this tree. */
    private int size = 0;

    /** Root node; {@code null} for an empty tree. */
    private TreeNode<E> root;

    /** Creates an empty tree (no root, size 0). */
    public MyTree() {}

    /**
     * Creates a tree rooted at {@code root}.
     *
     * <p>The size is initialised to the number of nodes already reachable from
     * {@code root}, so a {@code null} root yields size 0 and a pre-built
     * subtree is counted in full.
     *
     * @param root root node, may be {@code null}
     */
    public MyTree(TreeNode<E> root) {
        this.root = root;
        this.size = countNodes(root);
    }

    /** Returns the number of nodes in the tree. */
    @Override
    public int getSize() {
        return size;
    }

    /** Returns the root node, or {@code null} if the tree is empty. */
    @Override
    public TreeNode<E> getRoot() {
        return root;
    }

    /**
     * Returns the parent of {@code x}.
     *
     * @return the parent, or {@code null} if {@code x} has none
     */
    @Override
    public TreeNode<E> getParent(TreeNode<E> x) {
        return x.parent;
    }

    /**
     * Returns the first child of {@code x}.
     *
     * @return the first child, or {@code null} if {@code x} is a leaf
     */
    @Override
    public TreeNode<E> getFirstChild(TreeNode<E> x) {
        if (isLeaf(x)) {
            return null;
        }
        return x.children.get(0);
    }

    /**
     * Returns the sibling that directly follows {@code x} in its parent's
     * child list.
     *
     * @return the next sibling, or {@code null} if {@code x} has no parent,
     *         is not present in its parent's child list, or is the last child
     */
    @Override
    public TreeNode<E> getNextSibling(TreeNode<E> x) {
        TreeNode<E> parent = x.parent;
        if (parent == null || parent.children == null) {
            return null;
        }
        List<TreeNode<E>> siblings = parent.children;
        int index = siblings.indexOf(x);
        // indexOf yields -1 when x is absent; without this guard index + 1
        // would be 0 and the parent's first child would be returned.
        if (index < 0 || index + 1 >= siblings.size()) {
            return null;
        }
        return siblings.get(index + 1);
    }

    /**
     * Returns the height of the subtree rooted at {@code x}, measured in edges
     * on the longest downward path (a leaf has height 0).
     */
    @Override
    public int getHeight(TreeNode<E> x) {
        // An empty child list (e.g. after the last child was deleted) is a
        // leaf as well, not a node of height 1.
        if (isLeaf(x)) {
            return 0;
        }
        int tallest = 0;
        for (TreeNode<E> child : x.children) {
            tallest = Math.max(tallest, getHeight(child));
        }
        return tallest + 1;
    }

    /**
     * Appends {@code child} as the last child of {@code p} and sets its parent
     * reference to {@code p}.
     *
     * <p>The tree size grows by the number of nodes in the subtree rooted at
     * {@code child}, which mirrors how {@link #deleteChild} shrinks it.
     * Note: {@code child} is not detached from any previous parent.
     *
     * @throws NullPointerException if {@code p} or {@code child} is {@code null}
     */
    @Override
    public void insertChild(TreeNode<E> p, TreeNode<E> child) {
        // Reject null before mutating so a failed call cannot leave a null
        // entry behind in the child list.
        if (child == null) {
            throw new NullPointerException("child must not be null");
        }
        // Count before linking so the count reflects the subtree as passed in.
        int added = countNodes(child);
        if (p.children == null) {
            p.children = new ArrayList<>();
        }
        p.children.add(child);
        child.parent = p;
        size += added;
    }

    /**
     * Removes the child at zero-based index {@code i} from {@code p}.
     *
     * <p>Because the removed node's children are not re-attached anywhere,
     * removing a node removes its whole subtree; the size is reduced by the
     * number of nodes in that subtree. The removed node's parent reference is
     * cleared.
     *
     * @throws IndexOutOfBoundsException if {@code p} has no child at index {@code i}
     */
    @Override
    public void deleteChild(TreeNode<E> p, int i) {
        List<TreeNode<E>> children = p.children;
        if (children == null) {
            throw new IndexOutOfBoundsException(
                    "Node has no children; cannot delete child at index " + i);
        }
        // remove(int) validates the index before mutating the list.
        TreeNode<E> removed = children.remove(i);
        size -= countNodes(removed);
        if (removed != null) {
            // Detach so the removed subtree no longer points back into the tree.
            removed.parent = null;
        }
    }

    /**
     * Returns the number of nodes in the subtree rooted at {@code x},
     * including {@code x} itself.
     *
     * @return the node count, or 0 if {@code x} is {@code null}
     */
    public int getChildSize(TreeNode<E> x) {
        return countNodes(x);
    }

    /** Tells whether {@code node} has no children (null or empty list). */
    private static <T> boolean isLeaf(TreeNode<T> node) {
        return node.children == null || node.children.isEmpty();
    }

    /** Counts {@code node} plus all of its descendants; 0 for {@code null}. */
    private static <T> int countNodes(TreeNode<T> node) {
        if (node == null) {
            return 0;
        }
        int count = 1;
        if (node.children != null) {
            for (TreeNode<T> child : node.children) {
                count += countNodes(child);
            }
        }
        return count;
    }
}
测试类:
import org.junit.Before;
import org.junit.jupiter.api.Test;
class MyTreeTest {
MyTree<String> tree = new MyTree<>(new TreeNode("a"));
@Test
public void testGetHeight() {
TreeNode<String> root = tree.getRoot();
TreeNode<String> b = new TreeNode<>("b");
tree.insertChild(root, b);
TreeNode<String> c = new TreeNode<>("c");
tree.insertChild(root, c);
TreeNode<String> d = new TreeNode<>("d");
tree.insertChild(root, d);
TreeNode<String> e = new TreeNode<>("e");
tree.insertChild(b, e);
tree.insertChild(b, new TreeNode<>("f"));
tree.insertChild(c, new TreeNode<>("g"));
tree.insertChild(d, new TreeNode<>("h"));
TreeNode<String> i = new TreeNode<>("i");
tree.insertChild(e, i);
tree.insertChild(i, new TreeNode<>("j"));
System.out.println("height:" + tree.getHeight(tree.getRoot()));
System.out.println("size:" + tree.getSize());
tree.deleteChild(tree.getRoot(), 0);
System.out.println("height:" + tree.getHeight(tree.getRoot()));
System.out.println("size:" + tree.getSize());
}
}