Topic
- Tree
- Depth-first Search
- Breadth-first Search
Description
https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree/
We are given a binary tree (with root node root
), a target
node, and an integer value k
.
Return a list of the values of all nodes that have a distance k
from the target
node. The answer can be returned in any order.
Example 1:
Input: root = [3,5,1,6,2,0,8,null,null,7,4], target = 5, k = 2
Output: [7,4,1]
Explanation:
The nodes that are a distance 2 from the target node (with value 5)
have values 7, 4, and 1.
Note that the inputs "root" and "target" are actually TreeNodes.
The descriptions of the inputs above are just serializations of these objects.
Note:
- The given tree is non-empty.
- Each node in the tree has unique values
0 <= node.val <= 500
. - The
target
node is a node in the tree. 0 <= k <= 1000
.
Analysis
方法一:我写的。
- 找出root到target的路径,用DFS
- 以步骤1的路径的节点为基础,向上用BFS找出符合剩余距离的节点,加入结果集。
- 在target的左右子孙树中,用BFS找出符合距离节点,加入结果集。
方法二:别人写的。
- 用map存储root到target的路径上的各节点到target的距离,键为节点,值为距离值,如路径上节点有(3)->(5),target为(5),则map为{(3)=1, (5)=0}。用到DFS遍历二叉树得出这个map。
- 用DFS求层数方式+用到步骤1的map,很巧妙。Link
方法三:别人写的,方法二的精简缝合版,一时半会难看懂。
Submission
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import com.lun.util.BinaryTree.TreeNode;
public class AllNodesDistanceKInBinaryTree {
//方法一:我写的
public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
LinkedList<TreeNode> path = findTarget(root, target);
List<Integer> result = new ArrayList<>();
findUp(k, path, result);
findDown(k, path.peekLast(), result);
return result;
}
public void findUp(int k, LinkedList<TreeNode> path, List<Integer> result) {
int pathTargetIndex = path.size() - 1;
int pathInitCount = 1;
int currentIndex = pathTargetIndex - 1;
int remainder = k - pathInitCount;//剩余步数
for(;currentIndex >= 0 && remainder >= 0; //
pathInitCount++, currentIndex--, remainder = k - pathInitCount) {
TreeNode currentNode = path.get(currentIndex);
if(remainder > 0) {
TreeNode nextNode = path.get(currentIndex + 1);
//在右子树查找
LinkedList<TreeNode> queue = new LinkedList<>();
queue.offer(currentNode);
int count = 0;
while(!queue.isEmpty()) {
for(int size = queue.size(); size > 0; size--) {
TreeNode node = queue.poll();
if(currentNode.left == nextNode) {
//只查右子树
if(node.left != null && node.left != nextNode)
queue.offer(node.left);
if(node.right != null)
queue.offer(node.right);
}else {
if(node.left != null)//刚开始一层有点特殊
queue.offer(node.left);
if(node.right != null && node.right != nextNode)
queue.offer(node.right);
}
}
if(++count == remainder) {
queue.stream().forEach(a->result.add(a.val));
break;
}
}
}else if(remainder == 0) {
result.add(currentNode.val);
}
}
}
public void findDown(int k, TreeNode startNode, List<Integer> result) {
if(k == 0) {//findUp不用,只一个即可
result.add(startNode.val);
return;
}
LinkedList<TreeNode> queue = new LinkedList<>();
queue.offer(startNode);
int count = 0;
while(!queue.isEmpty()) {
for(int size = queue.size(); size > 0; size--) {
TreeNode node = queue.poll();
if(node.left != null)
queue.offer(node.left);
if(node.right != null)
queue.offer(node.right);
}
if(++count == k) {
queue.stream().forEach(a->result.add(a.val));
break;
}
}
}
//目标节点到跟节点的路径
@SuppressWarnings("unchecked")
public LinkedList<TreeNode> findTarget(TreeNode root, TreeNode target){
LinkedList<Object[]> stack = new LinkedList<>();
stack.push(new Object[] {root, new LinkedList<TreeNode>()});
while(!stack.isEmpty()) {
Object[] arr = stack.pop();
TreeNode node = (TreeNode)arr[0];
LinkedList<TreeNode> path = (LinkedList<TreeNode>)arr[1];
path.add(node);
if(node.val == target.val)
return path;
if(node.left != null && node.right == null) {
stack.push(new Object[] {node.left, path});
}
if(node.left == null && node.right != null) {
stack.push(new Object[] {node.right, path});
}
if(node.left != null && node.right != null) {
stack.push(new Object[] {node.right, new LinkedList<TreeNode>(path)});
stack.push(new Object[] {node.left, path});
}
}
return new LinkedList<TreeNode>();
}
//方法二:别人写的
public static class Solution1{
Map<TreeNode, Integer> map = new HashMap<>();
public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
List<Integer> res = new LinkedList<>();
find(root, target);
dfs(root, K, map.get(root), res);
return res;
}
// find target node first and store the distance in that path that we could use it later directly
private int find(TreeNode root, TreeNode target) {
if (root == null) return -1;
if (root.val == target.val) {
map.put(root, 0);
return 0;
}
int left = find(root.left, target);
if (left >= 0) {
map.put(root, left + 1);
return left + 1;
}
int right = find(root.right, target);
if (right >= 0) {
map.put(root, right + 1);
return right + 1;
}
return -1;
}
private void dfs(TreeNode root, int K, int length, List<Integer> res) {
if (root == null) return;
if (map.containsKey(root)) length = map.get(root);
if (length == K) res.add(root.val);
dfs(root.left, K, length + 1, res);
dfs(root.right, K, length + 1, res);
}
}
//方法三:别人写的2,方法二缝合版
public static class Solution2{
public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
List<Integer> res = new LinkedList<>();
if (K == 0) {
res.add(target.val);
} else {
dfs(res, root, target.val, K ,0);
}
return res;
}
private int dfs(List<Integer> res, TreeNode node, int target, int K, int depth) {
if (node == null) return 0;
if (depth == K) {
res.add(node.val);
return 0;
}
int left, right;
if (node.val == target || depth > 0) {
left = dfs(res, node.left, target, K, depth + 1);
right = dfs(res, node.right, target, K, depth + 1);
} else {
left = dfs(res, node.left, target, K, depth);
right = dfs(res, node.right, target, K, depth);
}
if (node.val == target) return 1;
if (left == K || right == K) {
res.add(node.val);
return 0;
}
if (left > 0) {
dfs(res, node.right, target, K, left + 1);
return left + 1;
}
if (right > 0) {
dfs(res, node.left, target, K, right + 1);
return right + 1;
}
return 0;
}
}
}
Test
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.hamcrest.collection.IsEmptyCollection;
import org.hamcrest.collection.IsIterableContainingInAnyOrder;
import org.junit.Test;
import com.lun.medium.AllNodesDistanceKInBinaryTree.Solution1;
import com.lun.medium.AllNodesDistanceKInBinaryTree.Solution2;
import com.lun.util.BinaryTree;
import com.lun.util.BinaryTree.TreeNode;
public class AllNodesDistanceKInBinaryTreeTest {
@Test
public void test() {
AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
TreeNode root = BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
assertThat(obj.distanceK(root, new TreeNode(5), 2), IsIterableContainingInAnyOrder.containsInAnyOrder(7,4,1));
}
@Test
public void test2() {
Solution1 obj = new Solution1();
TreeNode root = BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
assertThat(obj.distanceK(root, new TreeNode(5), 2), //
IsIterableContainingInAnyOrder.containsInAnyOrder(7,4,1));
}
@Test
public void test3() {
Solution2 obj = new Solution2();
TreeNode root = BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
assertThat(obj.distanceK(root, new TreeNode(5), 2), //
IsIterableContainingInAnyOrder.containsInAnyOrder(7,4,1));
}
@Test
public void testFind() {
AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
TreeNode root = BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
StringBuilder sb = new StringBuilder();
obj.findTarget(root, new TreeNode(4)).forEach(a->sb.append(a.val+","));
assertEquals("3,5,2,4,", sb.toString());
StringBuilder sb2 = new StringBuilder();
obj.findTarget(root, new TreeNode(8)).forEach(a->sb2.append(a.val+","));
assertEquals("3,1,8,", sb2.toString());
StringBuilder sb3 = new StringBuilder();
obj.findTarget(root, new TreeNode(5)).forEach(a->sb3.append(a.val+","));
assertEquals("3,5,", sb3.toString());
assertThat(obj.findTarget(root, new TreeNode(9)), IsEmptyCollection.empty());
StringBuilder sb4 = new StringBuilder();
obj.findTarget(root, new TreeNode(3)).forEach(a->sb4.append(a.val+","));
assertEquals("3,", sb4.toString());
}
@Test
public void testFindUp() {
AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
TreeNode root = BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
LinkedList<TreeNode> path = obj.findTarget(root, new TreeNode(5));
List<Integer> result = new ArrayList<>();
obj.findUp(3, path, result);
assertThat(result, IsIterableContainingInAnyOrder.containsInAnyOrder(0, 8));
List<Integer> result2 = new ArrayList<>();
obj.findUp(2, path, result2);
assertThat(result2, IsIterableContainingInAnyOrder.containsInAnyOrder(1));
}
@Test
public void testFindDown() {
AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
TreeNode root = BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
LinkedList<TreeNode> path = obj.findTarget(root, new TreeNode(5));
List<Integer> result = new ArrayList<>();
obj.findDown(2, path.peekLast(), result);
assertThat(result, IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4));
List<Integer> result2 = new ArrayList<>();
obj.findDown(1, path.peekLast(), result2);
assertThat(result2, IsIterableContainingInAnyOrder.containsInAnyOrder(6, 2));
}
}