选择算法(顺序统计量)

选择算法(顺序统计量)

引入

在计算机科学中,选择算法是一种在列表或数组中找到第 k k k个最小数字的算法。这样的数字被称为第 k k k个顺序统计量。

我们可以使用排序算法在 O ( n log ⁡ 2 n ) O(n\log_2n) O(nlog2n)时间内解决这个问题。但还有一种更快的方法可以在 O ( n ) O(n) O(n)的期望时间内解决这个问题。

期望为线性时间O(n)的算法

这个方法和快速排序类似,我们需要将输入的数组进行划分,与快速排序不同的是,快速排序会递归地处理划分完后的两边,而这个方法只处理一边。因此,期望时间为 O ( n ) O(n) O(n)(这取决于划分操作是否能更平分输入数组),我们将这个算法称为选择算法。

划分操作

图片来自算法导论,核心思想是将输入数组分为 4 4 4个部分进行操作。将输入数组的最后一个元素作为 p i v o t pivot pivot

划分操作

然后剩下三部分,第一部分为小于等于 p i v o t pivot pivot,第二部分大于 p i v o t pivot pivot,第三部分为待划分的区域。

这是C++的实现代码:

#include <algorithm>
#include <iterator>

template <typename RandIt>
inline RandIt partition(RandIt first, RandIt last) {
    if (first == last) {
        return first;
    }

    std::advance(last, -1); // pivot

    for (RandIt it = first; it != last; std::advance(it, 1)) {
        if (*it <= *last) {
            std::iter_swap(first, it);
            std::advance(first, 1);
        }
    }

    std::iter_swap(first, last);
    return first;
}

其中 f i r s t first first i t it it这两根指针(迭代器)将输入数组分为 3 3 3部分:

  • 起始位置到 f i r s t first first(左闭右开区间)小于等于 p i v o t pivot pivot
  • [ f i r s t , i t ) [first, it) [first,it)大于 p i v o t pivot pivot
  • [ i t , p i v o t ) [it, pivot) [it,pivot)为待划分部分

选择算法

这是C++的实现代码:

#include <iterator>

template <typename RandIt>
void nth_element(RandIt first, RandIt nth, RandIt last) {
    auto pivot = partition(first, last);
    auto distance = std::distance(pivot, nth);

    if (distance > 0) {
        std::advance(pivot, 1);
        nth_element(pivot, nth, last);
    } else if (distance < 0) {
        nth_element(first, nth, pivot);
    }
}

其中nth为第 n n n个元素(下标从 0 0 0开始), d i s t a n c e distance distance为第 n n n个元素与 p i v o t pivot pivot之间的距离。

  • 如果 d i s t a n c e distance distance等于 0 0 0,那么表明 n t h nth nth等于 p i v o t pivot pivot,因此 n t h nth nth左边的所有元素都小于或等于它,右边的所有元素都大于它,所以 n t h nth nth就是我们要找的第n小的元素。
  • 如果 d i s t a n c e distance distance大于 0 0 0,说明 n t h nth nth p i v o t pivot pivot的右边,此时只需要在区间 [ p i v o t + 1 , l a s t ) [pivot + 1, last) [pivot+1,last)中按照同样的方式寻找即可。
  • d i s t a n c e distance distance小于 0 0 0的情况与上面类似。

测试

简单的测试代码:

#include <algorithm>
#include <chrono>
#include <iostream>
#include <iterator>
#include <random>

int n, k;
int arr[5000005];
int arrCopy[5000005];

template <typename RandIt>
inline RandIt partition(RandIt first, RandIt last) {
    if (first == last) {
        return first;
    }

    std::advance(last, -1); // pivot

    for (RandIt it = first; it != last; std::advance(it, 1)) {
        if (*it <= *last) {
            std::iter_swap(first, it);
            std::advance(first, 1);
        }
    }

    std::iter_swap(first, last);
    return first;
}

template <typename RandIt>
void nth_element(RandIt first, RandIt nth, RandIt last) {
    auto pivot = partition(first, last);
    auto distance = std::distance(pivot, nth);

    if (distance > 0) {
        std::advance(pivot, 1);
        nth_element(pivot, nth, last);
    } else if (distance < 0) {
        nth_element(first, nth, pivot);
    }
}

int main() {
    std::random_device rd;
    std::mt19937 e(rd());
    std::cin >> n >> k;
    std::uniform_int_distribution dt(-n, n);

    for (int i = 0; i < n; i++) {
        arr[i] = dt(e);
        arrCopy[i] = arr[i];
    }

    auto begin = std::chrono::high_resolution_clock::now();
    nth_element(arr, arr + k, arr + n);
    auto time1 = std::chrono::high_resolution_clock::now() - begin;

    begin = std::chrono::high_resolution_clock::now();
    std::nth_element(arrCopy, arrCopy + k, arrCopy + n);
    auto time2 = std::chrono::high_resolution_clock::now() - begin;

    std::cout << arr[k] << "\n";
    std::cout << arrCopy[k] << "\n";
    std::cout << "nthElement: \t\t" << time1.count() << "ns\n";
    std::cout << "std::nth_element: \t" << time2.count() << "ns\n";
}

运行速度比C++标准库的std::nth_element速度稍微快一点:

测试

参考

《算法导论》

相关推荐
©️2020 CSDN 皮肤主题: 精致技术 设计师:CSDN官方博客 返回首页