LeetCode - Medium - 863. All Nodes Distance K in Binary Tree

Topic

  • Tree
  • Depth-first Search
  • Breadth-first Search

Description

https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree/

We are given a binary tree (with root node root), a target node, and an integer value k.

Return a list of the values of all nodes that have a distance k from the target node. The answer can be returned in any order.

Example 1:

Input: root = [3,5,1,6,2,0,8,null,null,7,4], target = 5, k = 2

Output: [7,4,1]

Explanation: 
The nodes that are a distance 2 from the target node (with value 5)
have values 7, 4, and 1.

Note that the inputs "root" and "target" are actually TreeNodes.
The descriptions of the inputs above are just serializations of these objects.

Note:

  1. The given tree is non-empty.
  2. Each node in the tree has unique values 0 <= node.val <= 500.
  3. The target node is a node in the tree.
  4. 0 <= k <= 1000.

Analysis

方法一:我写的。

  1. 找出root到target的路径,用DFS
  2. 对步骤1路径上target的每个祖先节点,先算出剩余步数(k减去该祖先到target的距离):剩余步数为0时,该祖先本身即为答案;大于0时,从该祖先出发、避开通往target的那个孩子,用BFS向下找出距离为剩余步数的节点,加入结果集。
  3. 在以target为根的子树中,用BFS找出距离target恰好为k的节点,加入结果集(k为0时即target本身)。

方法二:别人写的。

  1. 用map存储root到target的路径上的各节点到target的距离,键为节点,值为距离值,如路径上节点有(3)->(5),target为(5),则map为{(3)=1, (5)=0}。用到DFS遍历二叉树得出这个map。
  2. 再用一次DFS遍历整棵树,像求层数一样逐层累加距离:若当前节点在步骤1的map中,则直接取map中的距离;否则距离为父节点距离加1。距离等于k的节点加入结果集。这种做法很巧妙。Link

方法三:别人写的,是方法二的精简合并版,一时半会难看懂。它只用一次DFS:返回值表示target到当前节点父节点的距离(target不在该子树中时为0),祖先节点据此判断自己是否恰好距离k,否则带着距离进入另一侧子树继续向下搜索。

Submission

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import com.lun.util.BinaryTree.TreeNode;

public class AllNodesDistanceKInBinaryTree {

    // Approach 1 (my own solution): find the root-to-target path with DFS, then BFS outward
    // from every ancestor on that path and from the target itself.

    /**
     * Returns the values of all nodes whose distance from {@code target} is exactly {@code k}.
     *
     * <p>The target is located by value ({@code node.val == target.val}), relying on the
     * problem's guarantee that node values are unique.
     *
     * @param root   root of the tree
     * @param target node (matched by value) from which distances are measured
     * @param k      required distance
     * @return values at distance {@code k}, in no particular order; empty if no node with the
     *         target's value exists in the tree
     */
    public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        LinkedList<TreeNode> path = findTarget(root, target);
        List<Integer> result = new ArrayList<>();
        if (path.isEmpty()) {
            // Target not found (or empty tree): nothing can be at distance k from it.
            return result;
        }
        findUp(k, path, result);
        findDown(k, path.peekLast(), result);
        return result;
    }

    /**
     * Collects the nodes at distance {@code k} from the target that are reached by first going
     * up through one of the target's ancestors.
     *
     * <p>For the ancestor at distance {@code d} from the target, the remaining budget is
     * {@code k - d}. If it is 0, the ancestor itself qualifies. Otherwise the search continues
     * downward from the ancestor into every branch except the child leading back to the target.
     *
     * @param k      required distance from the target
     * @param path   root-to-target path as returned by {@link #findTarget}; its last element is
     *               the target
     * @param result list that matching node values are appended to
     */
    public void findUp(int k, LinkedList<TreeNode> path, List<Integer> result) {
        // LinkedList.get(i) is linear; copy once so the indexed lookups below are O(1).
        List<TreeNode> nodes = new ArrayList<>(path);
        int targetIndex = nodes.size() - 1;
        for (int distance = 1; distance <= k && targetIndex - distance >= 0; distance++) {
            TreeNode ancestor = nodes.get(targetIndex - distance);
            int remainder = k - distance; // steps still to take after reaching this ancestor
            if (remainder == 0) {
                result.add(ancestor.val);
            } else {
                // This child lies on the path back toward the target and must not be re-entered.
                TreeNode towardTarget = nodes.get(targetIndex - distance + 1);
                collectAtDepth(ancestor, remainder, towardTarget, result);
            }
        }
    }

    /**
     * Collects the nodes in the subtree rooted at {@code startNode} that lie exactly {@code k}
     * levels below it. When {@code k == 0}, that is {@code startNode} itself.
     *
     * @param k         number of levels to descend
     * @param startNode subtree root (normally the target)
     * @param result    list that matching node values are appended to
     */
    public void findDown(int k, TreeNode startNode, List<Integer> result) {
        collectAtDepth(startNode, k, null, result);
    }

    /**
     * Level-order walk from {@code start}. Appends the values of the nodes exactly
     * {@code depth} levels below it, never descending into {@code excluded}.
     *
     * @param start    node to start from; nothing is collected if {@code null}
     * @param depth    number of levels to descend; nothing is collected if negative
     * @param excluded child that must not be entered (may be {@code null})
     * @param result   list that matching node values are appended to
     */
    private static void collectAtDepth(TreeNode start, int depth, TreeNode excluded,
            List<Integer> result) {
        if (start == null || depth < 0) {
            return;
        }
        LinkedList<TreeNode> level = new LinkedList<>();
        level.offer(start);
        for (int d = 0; d < depth && !level.isEmpty(); d++) {
            for (int size = level.size(); size > 0; size--) {
                TreeNode node = level.poll();
                if (node.left != null && node.left != excluded) {
                    level.offer(node.left);
                }
                if (node.right != null && node.right != excluded) {
                    level.offer(node.right);
                }
            }
        }
        // If the tree ran out of levels before reaching `depth`, `level` is empty here.
        for (TreeNode node : level) {
            result.add(node.val);
        }
    }

    /**
     * Finds the path from {@code root} to the node whose value equals {@code target.val}, using
     * an iterative depth-first search that explores left subtrees before right ones.
     *
     * @param root   root of the tree (may be {@code null})
     * @param target node whose value identifies the node to find
     * @return the nodes on the path, in root-to-target order; an empty list if not found
     */
    public LinkedList<TreeNode> findTarget(TreeNode root, TreeNode target) {
        if (root == null) {
            return new LinkedList<>();
        }
        LinkedList<PathFrame> stack = new LinkedList<>();
        stack.push(new PathFrame(root, new LinkedList<TreeNode>()));

        while (!stack.isEmpty()) {
            PathFrame frame = stack.pop();
            TreeNode node = frame.node;
            LinkedList<TreeNode> path = frame.path;
            path.add(node);

            if (node.val == target.val) {
                return path;
            }

            // Right is pushed before left, so the left subtree is explored first. When both
            // children exist, the right one gets its own copy of the path. Otherwise the single
            // child can extend this list directly, because no other frame refers to it.
            if (node.right != null) {
                stack.push(new PathFrame(node.right,
                        node.left != null ? new LinkedList<TreeNode>(path) : path));
            }
            if (node.left != null) {
                stack.push(new PathFrame(node.left, path));
            }
        }

        return new LinkedList<>();
    }

    /** DFS stack entry: a node still to visit and the root-to-parent path that leads to it. */
    private static final class PathFrame {
        final TreeNode node;
        final LinkedList<TreeNode> path;

        PathFrame(TreeNode node, LinkedList<TreeNode> path) {
            this.node = node;
            this.path = path;
        }
    }

    /**
     * Approach 2 (someone else's solution). First record, for every node on the root-to-target
     * path, its distance to the target. Then DFS the whole tree, accumulating distances level by
     * level.
     */
    public static class Solution1 {
        // Distance to the target for each node on the root-to-target path (filled by find).
        Map<TreeNode, Integer> map = new HashMap<>();

        /**
         * Returns the values of all nodes at distance {@code K} from {@code target} (matched by
         * value). Safe to call repeatedly on the same instance.
         *
         * @return matching values; empty if the target is not in the tree
         */
        public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
            List<Integer> res = new LinkedList<>();
            // The map is instance state: drop entries left over from a previous call, otherwise
            // stale distances for a different target would corrupt this result.
            map.clear();
            find(root, target);
            Integer rootDistance = map.get(root);
            if (rootDistance == null) {
                // Target not present (or root is null): nothing is at distance K from it.
                return res;
            }
            dfs(root, K, rootDistance, res);
            return res;
        }

        /**
         * Finds the target and stores, for each node on the path to it, that node's distance to
         * the target, so the second pass can use those distances directly.
         *
         * @return distance from {@code root} to the target, or -1 if the target is not in this
         *         subtree
         */
        private int find(TreeNode root, TreeNode target) {
            if (root == null) return -1;
            if (root.val == target.val) {
                map.put(root, 0);
                return 0;
            }
            int left = find(root.left, target);
            if (left >= 0) {
                map.put(root, left + 1);
                return left + 1;
            }
            int right = find(root.right, target);
            if (right >= 0) {
                map.put(root, right + 1);
                return right + 1;
            }
            return -1;
        }

        /**
         * Visits every node and adds those at distance {@code K}. Nodes on the target path take
         * their exact distance from the map. Any other node is one step further from the target
         * than its parent.
         *
         * @param length distance from the target to {@code root} as inferred from its parent
         */
        private void dfs(TreeNode root, int K, int length, List<Integer> res) {
            if (root == null) return;
            if (map.containsKey(root)) length = map.get(root);
            if (length == K) res.add(root.val);
            dfs(root.left, K, length + 1, res);
            dfs(root.right, K, length + 1, res);
        }
    }

    /**
     * Approach 3 (another third-party solution). A condensed, single-pass variant of approach 2:
     * one DFS both locates the target and collects the answers.
     */
    public static class Solution2 {

        public List<Integer> distanceK(TreeNode root, TreeNode target, int K) {
            List<Integer> res = new LinkedList<>();
            if (K == 0) {
                res.add(target.val);
            } else {
                dfs(res, root, target.val, K, 0);
            }
            return res;
        }

        /**
         * @param target value of the target node
         * @param depth  0 while still searching for the target; otherwise the distance from the
         *               target to {@code node} while walking away from it
         * @return a positive value is the distance from the target to {@code node}'s parent,
         *         reported when the target is {@code node} or lies below it and that distance
         *         does not exceed K; 0 otherwise
         */
        private int dfs(List<Integer> res, TreeNode node, int target, int K, int depth) {
            if (node == null) return 0;

            if (depth == K) {
                res.add(node.val);
                return 0;
            }

            int left, right;
            if (node.val == target || depth > 0) {
                // At the target, or already moving away from it: children are one step further.
                left = dfs(res, node.left, target, K, depth + 1);
                right = dfs(res, node.right, target, K, depth + 1);
            } else {
                // Still searching for the target.
                left = dfs(res, node.left, target, K, depth);
                right = dfs(res, node.right, target, K, depth);
            }

            // The target's parent is at distance 1.
            if (node.val == target) return 1;

            // `left`/`right` is this node's distance from the target when the target lies on
            // that side.
            if (left == K || right == K) {
                res.add(node.val);
                return 0;
            }

            if (left > 0) {
                // Target is on the left: search the right subtree, whose root is left + 1 away.
                dfs(res, node.right, target, K, left + 1);
                return left + 1;
            }

            if (right > 0) {
                // Target is on the right: search the left subtree, whose root is right + 1 away.
                dfs(res, node.left, target, K, right + 1);
                return right + 1;
            }

            return 0;
        }
    }
}

Test

import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

import org.hamcrest.collection.IsEmptyCollection;
import org.hamcrest.collection.IsIterableContainingInAnyOrder;
import org.junit.Test;

import com.lun.medium.AllNodesDistanceKInBinaryTree.Solution1;
import com.lun.medium.AllNodesDistanceKInBinaryTree.Solution2;
import com.lun.util.BinaryTree;
import com.lun.util.BinaryTree.TreeNode;

public class AllNodesDistanceKInBinaryTreeTest {

    /** Builds the example tree from the problem statement: [3,5,1,6,2,0,8,null,null,7,4]. */
    private static TreeNode sampleTree() {
        return BinaryTree.integers2BinaryTree(3, 5, 1, 6, 2, 0, 8, null, null, 7, 4);
    }

    @Test
    public void test() {
        AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
        TreeNode root = sampleTree();

        assertThat(obj.distanceK(root, new TreeNode(5), 2),
                IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4, 1));
    }

    @Test
    public void test2() {
        Solution1 obj = new Solution1();
        TreeNode root = sampleTree();

        assertThat(obj.distanceK(root, new TreeNode(5), 2),
                IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4, 1));
    }

    @Test
    public void test3() {
        Solution2 obj = new Solution2();
        TreeNode root = sampleTree();

        assertThat(obj.distanceK(root, new TreeNode(5), 2),
                IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4, 1));
    }

    /** k == 0 must return only the target itself. */
    @Test
    public void testZeroDistance() {
        assertThat(new AllNodesDistanceKInBinaryTree().distanceK(sampleTree(), new TreeNode(5), 0),
                IsIterableContainingInAnyOrder.containsInAnyOrder(5));
        assertThat(new Solution1().distanceK(sampleTree(), new TreeNode(5), 0),
                IsIterableContainingInAnyOrder.containsInAnyOrder(5));
        assertThat(new Solution2().distanceK(sampleTree(), new TreeNode(5), 0),
                IsIterableContainingInAnyOrder.containsInAnyOrder(5));
    }

    /** With the root as target there are no ancestors; only the downward search contributes. */
    @Test
    public void testTargetIsRoot() {
        assertThat(new AllNodesDistanceKInBinaryTree().distanceK(sampleTree(), new TreeNode(3), 3),
                IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4));
        assertThat(new Solution1().distanceK(sampleTree(), new TreeNode(3), 3),
                IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4));
        assertThat(new Solution2().distanceK(sampleTree(), new TreeNode(3), 3),
                IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4));
    }

    /** With a leaf as target, every answer is reached through an ancestor. */
    @Test
    public void testTargetIsLeaf() {
        assertThat(new AllNodesDistanceKInBinaryTree().distanceK(sampleTree(), new TreeNode(7), 3),
                IsIterableContainingInAnyOrder.containsInAnyOrder(6, 3));
        assertThat(new Solution1().distanceK(sampleTree(), new TreeNode(7), 3),
                IsIterableContainingInAnyOrder.containsInAnyOrder(6, 3));
        assertThat(new Solution2().distanceK(sampleTree(), new TreeNode(7), 3),
                IsIterableContainingInAnyOrder.containsInAnyOrder(6, 3));
    }

    /** A distance larger than any path in the tree yields no nodes. */
    @Test
    public void testDistanceBeyondTree() {
        assertThat(new AllNodesDistanceKInBinaryTree().distanceK(sampleTree(), new TreeNode(5), 10),
                IsEmptyCollection.empty());
        assertThat(new Solution1().distanceK(sampleTree(), new TreeNode(5), 10),
                IsEmptyCollection.empty());
        assertThat(new Solution2().distanceK(sampleTree(), new TreeNode(5), 10),
                IsEmptyCollection.empty());
    }

    @Test
    public void testFind() {
        AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
        TreeNode root = sampleTree();

        StringBuilder sb = new StringBuilder();
        obj.findTarget(root, new TreeNode(4)).forEach(a -> sb.append(a.val + ","));
        assertEquals("3,5,2,4,", sb.toString());

        StringBuilder sb2 = new StringBuilder();
        obj.findTarget(root, new TreeNode(8)).forEach(a -> sb2.append(a.val + ","));
        assertEquals("3,1,8,", sb2.toString());

        StringBuilder sb3 = new StringBuilder();
        obj.findTarget(root, new TreeNode(5)).forEach(a -> sb3.append(a.val + ","));
        assertEquals("3,5,", sb3.toString());

        assertThat(obj.findTarget(root, new TreeNode(9)), IsEmptyCollection.empty());

        StringBuilder sb4 = new StringBuilder();
        obj.findTarget(root, new TreeNode(3)).forEach(a -> sb4.append(a.val + ","));
        assertEquals("3,", sb4.toString());
    }

    @Test
    public void testFindUp() {
        AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
        TreeNode root = sampleTree();

        LinkedList<TreeNode> path = obj.findTarget(root, new TreeNode(5));
        List<Integer> result = new ArrayList<>();
        obj.findUp(3, path, result);
        assertThat(result, IsIterableContainingInAnyOrder.containsInAnyOrder(0, 8));

        List<Integer> result2 = new ArrayList<>();
        obj.findUp(2, path, result2);
        assertThat(result2, IsIterableContainingInAnyOrder.containsInAnyOrder(1));
    }

    @Test
    public void testFindDown() {
        AllNodesDistanceKInBinaryTree obj = new AllNodesDistanceKInBinaryTree();
        TreeNode root = sampleTree();

        LinkedList<TreeNode> path = obj.findTarget(root, new TreeNode(5));

        List<Integer> result = new ArrayList<>();
        obj.findDown(2, path.peekLast(), result);
        assertThat(result, IsIterableContainingInAnyOrder.containsInAnyOrder(7, 4));

        List<Integer> result2 = new ArrayList<>();
        obj.findDown(1, path.peekLast(), result2);
        assertThat(result2, IsIterableContainingInAnyOrder.containsInAnyOrder(6, 2));
    }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值