/**
 * A binary tree node holding an {@code int} value and links to its two children.
 */
public class TreeNode {
    /** Value stored in this node. */
    int val;
    /** Left child; {@code null} until assigned. */
    TreeNode left;
    /** Right child; {@code null} until assigned. */
    TreeNode right;

    /**
     * Creates a node with the given value and no children.
     *
     * @param x the value to store
     */
    TreeNode(int x) {
        this.val = x;
    }
}
package Group2;
import java.util.*;
public class Test3 {
    /**
     * Returns the values of all nodes that are exactly {@code K} edges away from {@code target}.
     *
     * <p>First every node's parent is recorded so the search can also move upward; then a
     * depth-first search spreads from the target in three directions (left child, right child,
     * parent) until it has walked {@code K} edges. All traversal state is local to the call, so
     * one instance can answer any number of queries.
     *
     * @param root   root of the tree; {@code null} yields an empty list
     * @param target node to measure from; if this exact node is in the tree it is used,
     *               otherwise the first node in preorder with the same value. {@code null}
     *               or no match yields an empty list
     * @param K      distance in edges; a negative value yields an empty list
     * @return values of the nodes at distance {@code K}, in the order the search reaches them
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
        List<Integer> res = new ArrayList<>();
        if (root == null || target == null || K < 0) {
            return res;
        }
        // LinkedHashMap keeps preorder insertion order for the value-based fallback in locate().
        Map<TreeNode, TreeNode> parents = new LinkedHashMap<>();
        recordParents(root, null, parents);
        TreeNode start = locate(target, parents);
        if (start != null) {
            dfs(start, K, parents, new HashSet<>(), res);
        }
        return res;
    }

    /** Records each node's parent (the root maps to {@code null}), visiting nodes in preorder. */
    private static void recordParents(TreeNode node, TreeNode parent,
                                      Map<TreeNode, TreeNode> parents) {
        if (node == null) {
            return;
        }
        parents.put(node, parent);
        recordParents(node.left, node, parents);
        recordParents(node.right, node, parents);
    }

    /**
     * Finds the tree node to start from: {@code target} itself if it is in the tree,
     * otherwise the first node (preorder) whose value equals {@code target.val}.
     * Returns {@code null} when neither exists.
     */
    private static TreeNode locate(TreeNode target, Map<TreeNode, TreeNode> parents) {
        // TreeNode does not override equals/hashCode, so this is an identity lookup.
        if (parents.containsKey(target)) {
            return target;
        }
        for (TreeNode node : parents.keySet()) {
            if (node.val == target.val) {
                return node;
            }
        }
        return null;
    }

    /**
     * Spreads depth-first from {@code node} and adds to {@code out} every node reached after
     * exactly {@code remaining} more edges. {@code visited} stops the walk from going back
     * over a node it has already passed through.
     */
    private static void dfs(TreeNode node, int remaining, Map<TreeNode, TreeNode> parents,
                            Set<TreeNode> visited, List<Integer> out) {
        // Set.add returns false when the node was already visited.
        if (node == null || !visited.add(node)) {
            return;
        }
        if (remaining == 0) {
            out.add(node.val);
            return;
        }
        dfs(node.left, remaining - 1, parents, visited, out);
        dfs(node.right, remaining - 1, parents, visited, out);
        // The root's parent is null, which the guard above ignores.
        dfs(parents.get(node), remaining - 1, parents, visited, out);
    }
}
/* Abandoned draft of a path-based approach; incomplete (always returns null) and not compiled.
   public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
List<List<Integer>> paths = new ArrayList<>();
List<Integer> path = new ArrayList<>();
        // Use backtracking to collect the paths from the root down to nodes
dfs(root, paths, path);
return null;
}
private void dfs(TreeNode root, List<List<Integer>> paths, List<Integer> path) {
if (root == null) {
paths.add(new ArrayList<>(path));
return;
}
path.add(root.val);
dfs(root.left, paths, path);
if (root.right != null) {
dfs(root.right, paths, path);
}
path.remove(path.size() - 1);
}*/