tarjan算法_[力扣215] 快速选择算法&划分树

7725709decccd8b04dc5bae8e83a037b.png

题目链接

215. 数组中的第K个最大元素

题目描述

在未排序的数组中找到第k个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

样例

示例 1:
输入: [3,2,1,5,6,4] 和 k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6] 和 k = 4
输出: 4

topK问题有两种主流做法

  1. 减治(quick selection 算法)

算法1:堆

开一个小顶堆,用于维护枚举到 nums[i] 时的前 k 个最大元素

// 直接写 priority_queue<int> pq; 是大顶堆,小顶堆是下面的写法
priority_queue<int, vector<int>, greater<int> > pq;

先把前 k 个数压进去。然后枚举剩余的数,只要堆顶小于当前枚举的数,就先弹出堆顶在压入当前值,保持堆里是前 K 个最大元素。枚举结束后,堆顶就是答案。

时间复杂度稳定 O(NlogK)

代码(c++)

class Solution {
public:
    int findKthLargest(vector<int>& nums, int k) {
        int n = nums.size();
        if(n == 1) return nums[0];
        priority_queue<int, vector<int>, greater<int> > pq;
        for(int i = 0; i < k; ++i)
            pq.push(nums[i]);
        for(int i = k; i < n; ++i)
            if(pq.top() < nums[i])
            {
                pq.pop();
                pq.push(nums[i]);
            }
        return pq.top();
    }
};

算法2:快速选择算法(减治)

在子区间 [left, right] 中选择第 k 大的数时,完全照搬快排的划分算法:

  1. 选择一个枢轴(pivot),然后交换到 left 位置
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);

2. 使用划分算法确定 pivot 的位置。大于 pivot 的元素移到左边,小于等于 pivot 的元素移到右边。

int pivot = nums[left];
int l = left, r = right;
while(l < r)
{
    //...
}
// 一轮 partition 完成

一轮 partition 完成后,pivot 的位置为 l,此时考察 l 和 left + K 的关系,[left, right] 中第 K 个位置的下标是 left + K - 1:

l = left + K - 1: pivot 刚好在 [left, right] 第 k 个位置,找到答案了
l > left + K - 1: [left, l] 中有 l - left + 1 个数字,还要在 [l + 1, right] 中找 K - (l - left + 1) 个
l < left + K - 1: 在 [left, l - 1] 中继续找第 k 大

c6ad6baf15c9c0b8a50dcd143cbe34d9.png

时间复杂度 平均O(N),最坏

,但是最坏情况太难达到了,一般还是认为
快速选择算法是 O(N) 的。

快速选择算法有一个优化:BFPRT算法,也叫中位数的中位数算法,它进一步优化了 pivot 的选取方法,使得最坏时间复杂度也变为

,它由Blum、Floyd、Pratt、Rivest、Tarjan提出。

快速选择算法每完成一轮 partition 将数据分成不均匀的两份后,只有其中一份对结果有影响,可以去掉一部分,数据规模就减少一部分,所以也叫减治算法。插入排序,DFS, BFS, 拓扑排序也隐含了减治的思想:每完成一轮,下一轮就可以少考虑一个数据。

代码(c++)

class Solution {
public:
    int findKthLargest(vector<int>& nums, int k) {
        int n = nums.size();
        if(n == 1) return nums[0];
        return partition(nums, k, 0, n - 1);
    }

private:
    int partition(vector<int>& nums, int k, int left, int right)
    {
        // 在 nums 的 [left .. right] 中找第 k 大
        int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
        swap(nums[randIdx], nums[left]);
        int pivot = nums[left];
        int l = left, r = right;
        while(l < r)
        {
            while(l < r && nums[r] <= pivot)
                --r;
            if(l < r)
            {
                nums[l] = nums[r];
                ++l;
            }
            while(l < r && nums[l] > pivot)
                ++l;
            if(l < r)
            {
                nums[r] = nums[l];
                --r;
            }
        }
        nums[l] = pivot;
        if(l == left + k - 1)
            return nums[l];
        else if(l < left + k - 1)
            return partition(nums, k - (l - left + 1), l + 1, right);
        else
            return partition(nums, k, left, l - 1);
    }
};

算法3: 归并树

直接套文章 [力扣315] 索引数组&CDQ分治&归并树 里的归并树的思路:已经根据原始数组建好归并树之后,可以

得到区间 [i, j] 中大于 x 的数的个数

在归并树建树过程中可以顺便得到数组的最大值 right 和最小值 left,答案肯定在 [left, right]。

值域二分:每次从答案范围 [left, right] 中猜一个答案 mid = (left + right) / 2; 然后查询区间 [left, right] 中大于 mid 的个数 cnt

cnt = k - 1: mid 是答案,返回 mid
cnt > k - 1: mid 猜小了,left = mid + 1 继续猜
cnt < k - 1: mid + 1 肯定大了, right = mid 继续猜

时间复杂度

,明显弱于快速选择算法的 O(N)。但如果是 M 次查询不同的区间,则为
,其中前半部分是1次建树,后半部分是 M 次查询。这比快速选择算法的O(MN)好。

代码(c++)

struct MTNode
{
    int start, end;
    vector<int> data;
    MTNode *left, *right;
    MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
        :start(s),end(e),data(nums),left(l),right(r) {}
    ~MTNode()
    {
        if(left)
        {
            delete left;
            left = nullptr;
        }
        if(right)
        {
            delete right;
            right = nullptr;
        }
    }
};

class MergeTree
{
public:
    MergeTree()
    {
        root = nullptr;
    }

    ~MergeTree()
    {
        if(root)
        {
            delete root;
            root = nullptr;
        }
    }

    void build(int start, int end, const vector<int>& nums)
    {
        root = _build(start, end, nums);
    }

    int query(int i, int j, int k)
    {
        if(i > j) return 0;
        int result = 0;
        _query(root, i, j, k, result);
        return result;
    }

    int get(int i)
    {
        return (root -> data)[i];
    }

private:
    MTNode *root;

    void _query(MTNode* root, int i, int j, int k, int& result)
    {
        if(root -> start == i && root -> end == j)
        {
            auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
            result += (root -> data).end() - pos;
            return;
        }
        int mid = root -> start + (root -> end - root -> start) / 2;
        if(j <= mid)
        {
            _query(root -> left, i, j, k, result);
            return;
        }
        if(i > mid)
        {
            _query(root -> right, i, j, k, result);
            return;
        }
        _query(root -> left, i, mid, k, result);
        _query(root -> right, mid + 1, j, k, result);
    }

    MTNode* _build(int start, int end, const vector<int>& nums)
    {
        if(start == end)
        {
            return new MTNode(start, end, vector<int>({nums[start]}));
        }
        int mid = start + (end - start) / 2;
        MTNode *left = _build(start, mid, nums);
        MTNode *right = _build(mid + 1, end, nums);
        vector<int> merged((left -> data).size() + (right -> data).size());
        merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
        MTNode *cur = new MTNode(start, end, merged, left, right);
        return cur;
    }
};

class Solution_3 {
public:
    int findKthLargest(vector<int>& nums, int k) {
        int n = nums.size();
        if(n == 1) return nums[0];
        MergeTree mergetree;
        mergetree.build(0, n - 1, nums);
        int left = mergetree.get(0), right = mergetree.get(n - 1);
        while(left < right)
        {
            int mid = left + (right - left) / 2;
            int cnt = mergetree.query(0, n - 1, mid); // > mid 的元素个数
            if(cnt == k - 1)
                return mid;
            else if(cnt > k - 1)
                left = mid + 1;
            else
                right = mid;
        }
        return left;
    }
};

引申: 区间第k大 & 划分树

划分树和归并树都是用线段树作为辅助的,其中划分树的建树是模拟快排过程,自顶向下建树,归并树是模拟归并排序过程,自底向上建树,两种方法含有分治思想。

划分树的基本思想就是对于某个节点对应的区间 [start, end],把它划分成两个子区间,左边区间的数小于等于右边区间的数,左子区间对应划分后的 [start, mid],右子区间对应划分后的 [mid + 1, end],其中 mid = (left + right) / 2。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,直到区间长度变成1(start = end, 叶子节点),就找到了。

划分树的节点定义

start, end:节点对应的区间
nums[0..end-start]:节点持有的数据,一共 end - start + 1 个,是原始数组排好序之后在 [start, end] 范围的数字,但这里 nums 并没按排好序的顺序存放,而是按原数组的顺序放的。它具体含了哪些数字来自父节点划分的结果
toleft[0..end-start+2]:记录进入左子树的数字的个数
toleft[i] := [0 .. i - 1] 这 i 个数中,进入左子树的数字的个数,有前缀和的思想
toleft[0] = 0
struct PTNode
{
    int start, end;
    vector<int> nums, toleft;
    PTNode *left, *right;
    PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
        :start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
    ~PTNode()
    {
        if(left)
        {
            delete left;
            left = nullptr;
        }
        if(right)
        {
            delete right;
            right = nullptr;
        }
    }
};

划分树的建树

过程基本就是模拟快排过程,对于区间 [start, end] 取一个已经排过序的区间中位数,然后把小于中值的点放左边,大于的放右边,等于中位数的需要单独统计,使得进入左子树的数字个数为 [mid - left + 1],确保树是平衡的。划分树建树之前先求取并持有原数组排序后的数组 sorted,节点的中位数直接就取 sorted[mid]。

int median =  sorted[mid];

只要当前节点不是叶子节点(start = end),就先建好两个子节点

PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);

然后执行划分的算法流程,顺序枚举当前节点的所有数据,判断应该划分的方向,往对应的子节点塞进去。如果是塞进了左子树,toleft 加一。

8bdb7a2c8017420223409be4152941e4.png
[5, 2, 6, 1] 的建树过程

划分树的查询

query(i, j, k) 查询区间 [i, j] 中第 k 大的数

首先确定 [start ..i - 1], [start..j] 中去往左子树的数字个数 tli, tlj。其中 start = i 的情况需要特判。tlj - tli 是 [i..j] 中去往左子树的元素个数,记 tl。

int tli = 0;
if(root -> start != i)
    tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;

看 k 和 tl 的关系决定子查询,类似于快速选择算法中考察 l 和 left + K 的关系。

tl >= k: 第 k 大在左子树,答案是 query(root -> left, new_i, new_j, k)
tl < k: 第 k 大在右子树,答案是 query(root -> right, new_i, new_j, k - tl)
new_i, new_j 需要画图推导

49790dfdff3c098603856516e5cefc61.png

用划分树 AC 本题的代码

但是跑的效果很差,因为查询只有1次,需要查询次数较多时才能体现划分树的优势

查询次数为M时,时间复杂度 O(NlogN + MlogN),比M次快速选择的 O(MN) 和归并树的

都好。
struct PTNode
{
    int start, end;
    vector<int> nums, toleft;
    PTNode *left, *right;
    PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
        :start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
    ~PTNode()
    {
        if(left)
        {
            delete left;
            left = nullptr;
        }
        if(right)
        {
            delete right;
            right = nullptr;
        }
    }
};

class PartitionTree
{
public:
    PartitionTree()
    {
        root = nullptr;
        sorted = vector<int>();
    }

    ~PartitionTree()
    {
        if(root)
        {
            delete root;
            root = nullptr;
        }
    }

    void build(int start, int end, const vector<int>& nums)
    {
        sorted = nums;
        sort(sorted.begin(), sorted.end());
        root = new PTNode(start, end);
        root -> nums = nums;
        _build(root);
    }

    int query(int i, int j, int k)
    {
        if(i > j) return 0;
        return _query(root, i, j, k);
    }

private:
    PTNode *root;
    vector<int> sorted;

    int _query(PTNode* root, int i, int j, int k)
    {
        if(root -> start == root -> end) return (root -> nums)[0];

        int tli = 0;
        if(root -> start != i)
            tli = (root -> toleft)[i - root -> start];
        int tlj = (root -> toleft)[j - root -> start + 1];
        int tl = tlj - tli;
        int new_i, new_j;
        if(tl >= k)
        {
            // 第 k 大在左子
            new_i = root -> start + tli;
            new_j = new_i + tl - 1;
            return _query(root -> left, new_i, new_j, k);
        }
        else
        {
            // 第 k 大在右子
            int mid = root -> start + (root -> end - root -> start) / 2;
            new_i = mid + 1 + i - (root -> start) - tli;
            new_j = new_i + j - i - tl;
            return _query(root -> right, new_i, new_j, k - tl);
        }
    }

    void _build(PTNode* root)
    {
        if(root -> start == root -> end)
            return;
        int mid = root -> start + (root -> end - root -> start) / 2;
        int median =  sorted[mid];
        PTNode *left = new PTNode(root -> start, mid);
        PTNode *right = new PTNode(mid + 1, root -> end);
        int n = (root -> nums).size();
        int median_to_left = mid - root -> start + 1;
        for(int i = 0; i < n; ++i)
        {
            if((root -> nums)[i] < median)
                --median_to_left;
        }
        // 出循环后 median_to_left 为去往左子树中等于中位数的个数
        int to_left = 0; // 去往左子树的个数
        int idx_left = 0, idx_right = 0;
        for(int i = 0; i < n; ++i)
        {
            int cur = (root -> nums)[i];
            if(cur < median || ((cur == median) && median_to_left > 0))
            {
                (left -> nums)[idx_left] = cur;
                ++idx_left;
                ++to_left;
                if(cur == median)
                    --median_to_left;
            }
            else
            {
                (right -> nums)[idx_right] = cur;
                ++idx_right;
            }
            (root -> toleft)[i + 1] = to_left;
        }
        _build(left);
        _build(right);
        root -> left = left;
        root -> right = right;
    }
};

class Solution_4 {
public:
    int findKthLargest(vector<int>& nums, int k) {
        int n = nums.size();
        if(n == 1) return nums[0];
        PartitionTree partitiontree;
        partitiontree.build(0, n - 1, nums);
        return partitiontree.query(0, n - 1, n + 1 - k);
    }
};
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值