题目:
Given a non-empty binary search tree and a target value, find k values in the BST that are closest to the target.
Note:
- Given target value is a floating point.
- You may assume k is always valid, that is: k ≤ total nodes.
- You are guaranteed to have only one unique set of k values in the BST that are closest to the target.
Follow up:
Assume that the BST is balanced, could you solve it in less than O(n) runtime (where n = total nodes)?
思路:
1、中序遍历法:利用二叉搜索树的中序遍历,如果结果集合中元素还不到k个,就把当前元素加到集合中去;如果集合中的元素已经到达k个了,那么有如下情况:1)target的值比集合中最小的值小,因为中序遍历是有序的,最小元素就是第0个元素,所以此时可以直接返回;2)target的值大于集合中的最小值,此时我们就比较当前值和集合中的最小值,如果当前结点值比那个值更靠近target,那么我们就用当前元素替换最小值。直到我们无法再找到可以替换的元素,就可以返回结果了。该算法的时间复杂度是O(n)。
2、双栈法:这是一种时间复杂度更低的算法,利用两个栈来保存target的前驱和后继,而且栈顶元素保存的是距离target最近的前驱和后继,这样就可以每次取到距离target最近的值。这种算法的时间复杂度是O(klogn)。
代码:
1、中序遍历法:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
vector<int> closestKValues(TreeNode* root, double target, int k) {
if (!root) {
return ret;
}
closestKValues(root->left, target, k);
if (ret.size() < k) {
ret.push_back(root->val);
}
else if (fabs(target - ret[index]) > fabs(target - root->val)) {
ret[index++] = root->val;
if (index == k) {
index = 0;
}
}
else {
return ret;
}
closestKValues(root->right, target, k);
return ret;
}
private:
vector<int> ret;
int index = 0;
};
2、双栈法:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
vector<int> closestKValues(TreeNode* root, double target, int k) {
if(!root) {
return {};
}
initStack(root, target);
while(k-- > 0) {
if(pre.empty() || (!suc.empty() && fabs(pre.top()->val-target) > fabs(suc.top()->val-target))) {
ans.push_back(suc.top()->val); // the cloeset one is from suc.top();
getSuc();
}
else {
ans.push_back(pre.top()->val); // the cloeset one is from pre.top();
getPre();
}
}
return ans;
}
private:
stack<TreeNode*> pre, suc;
vector<int> ans;
void initStack(TreeNode* root, double target) { // the time complexity is O(logn)
while(root) {
if(root->val <= target) {
pre.push(root); // in pre, the values are smaller or equal than target
root = root->right;
}
else {
suc.push(root); // in suc, the values are larger than target
root = root->left;
}
}
}
void getPre() // the time complexity is O(logn)
{
auto node = pre.top();
pre.pop();
if(node->left) {
pre.push(node = node->left);
while(node->right) {
pre.push(node = node->right);
}
}
}
void getSuc() { // the time complexity is O(logn)
auto node = suc.top();
suc.pop();
if(node->right) {
suc.push(node = node->right);
while(node->left) {
suc.push(node = node->left);
}
}
}
};