最小的前k个数

最近在hadoop上写了一个分布式的KNN,其实本地的KNN也是需要用到的一个东西——从N个距离中选出K个最小的,从而投票得出预测的类别。

这个问题看似很简单,但要做好,还是有难度的。

方案1

n个数字排序,然后截取前k个作为结果。
时间复杂度:O(nlog n)
空间复杂度:O(n)
时间复杂度还好,但是空间复杂度是关于输入规模线性的就不好了,因为当需要预测的数据很多,多到单机内存放不下时,就得考虑外排序了,这样设计起来,整个系统变复杂很多了,不是一个很好的设计。

方案2

使用插入排序维护一个大小为K+1的数组,比如K=3,然后依次插入5,3,1,4,6,0,7,那么数组的变化情况如下:

下面的[x]表示“未赋值”:
(1) 5 [x] [x] [x]
(2) 3 5 [x] [x]
(3) 1 3 5 [x]
(4) 1 3 4 [5](目前最大的5溢出了,我们只关注前k=3个数字)
(5) 1 3 4 [6](很明显,新插入的6也处于“溢出位”)
(6) 0 1 3 [4](4被挤到了“溢出位”)
(6) 0 1 3 [7](注意“溢出位”是被后续的数据覆盖掉的)
所以最终的输出为前k=3个最小的数字:0, 1, 3

我把它封装成一个C++的模板类,这么简单的算法,相信理解的人很容易将其转换成别的语言。
复杂度分析:
1. 时间复杂度:O(KN),当N > 2^K时,很明显这个方案是很有优势的,而一般K很小,这个条件比较容易满足。
2. 空间复杂度:O(K),注意K在KNN算法里是固定大小的,不随输入数据的规模改变

最后,注意我为什么说数组的大小是K+1而不是K,这是因为编程实现的时候比较方便,不用特判去处理边界。

// TopKMinElem.h
#ifndef __TOK_K_MIN_ELEM__H
#define __TOK_K_MIN_ELEM__H

#include <vector>
using namespace std;


template<typename KeyType, typename ElemType>
class TopKMinElem
{
private:
    int index, capacity;
    vector<pair<KeyType, ElemType> > data;

public:
    TopKMinElem(int k)
    : index(0), capacity(k), data(k+1) {}


    void insert(KeyType k, ElemType e) {
        int pos = index - 1;
        while (pos >= 0 && k < data[pos].first) {
            data[pos+1] = data[pos];
            --pos;
        }
        if (pos+1 < capacity)
            data[pos+1] = make_pair(k, e);
        index = min(index+1, capacity);
    }

    vector<pair<KeyType, ElemType> > getTopK() const {
        return vector<pair<KeyType, ElemType> >(data.begin(), data.begin()+index);
    }
};

#endif

然后简单测试代码为:

// main.cpp
#include <stdio.h>
#include <vector>
#include <string>

#include "TopKMinElem.h"

using namespace std;

void show(const vector<pair<int, string> >& v) {
    printf("\n---------\n");
    for (int i = 0; i < v.size(); ++i)
        printf("[%d] (%d, %s)\n", i, v[i].first, v[i].second.c_str());
}

int main() {
    vector<pair<int, string> > v;
    TopKMinElem<int, string> tk(4);;
    tk.insert(1, string("first"));
    show(tk.getTopK());

    tk.insert(2, string("second"));
    show(tk.getTopK());

    tk.insert(3, string("third"));
    show(tk.getTopK());

    tk.insert(-1, string("fourth"));
    show(tk.getTopK());

    tk.insert(100, string("fivth"));
    show(tk.getTopK());

    tk.insert(0, string("sixth"));
    show(tk.getTopK());

    return 0;
}

运行结果为:


---------
[0] (1, first)

---------
[0] (1, first)
[1] (2, second)

---------
[0] (1, first)
[1] (2, second)
[2] (3, third)

---------
[0] (-1, fourth)
[1] (1, first)
[2] (2, second)
[3] (3, third)

---------
[0] (-1, fourth)
[1] (1, first)
[2] (2, second)
[3] (3, third)

---------
[0] (-1, fourth)
[1] (0, sixth)
[2] (1, first)
[3] (2, second)

方案3

有想要精益求精的朋友可能会想——有没有更好的办法呢?
其实一般在思考更好的办法的时候,我们都应该先清楚考虑当前所有的方案的优劣。
方案2已经很好了呀,哪里还可以改善呢?
关键在于K,使用堆来优化,而不是插入排序那样一个一个比较过去的话,时间复杂度可以优化到O(Nlog K)!!!
怎么实现呢?
!!!注意,继续阅读下去的话,需要懂得堆这个数据结构以及它的数组实现!!!不然应该是看不懂的。

我的做法是——使用最大堆,维护一个大小为K的最大堆,算法流程为:
1. 当堆的大小小于K,则正常地将元素插入到堆里头,形成最大堆。
2. 当堆的大小达到K时,以后的每次插入元素,都将新元素与堆顶元素比较,保留较小的元素,如果堆顶元素有变化,则调用一遍“维护堆的性质的函数”。
3. END!
这个流程很简单,也很容易理解,前提是懂得堆怎么实现,下面贴一下我的实现和测试代码:

// TopKMinHeap.h
#ifndef __TOP_K_MIN_HEAP__
#define __TOP_K_MIN_HEAP__

#include <stdio.h>
#include <vector>
#include <stdexcept>
using namespace std;

template<typename KeyType, typename ElemType>
class TopKMinHeap
{
private:
    size_t size, capacity;
    vector<pair<KeyType, ElemType> > data;

    size_t parent(size_t i) {
        return i >> 1;
    }

    size_t leftChild(size_t i) {
        return (i << 1);
    }

    size_t rightChild(size_t i) {
        return (i << 1) | 1;
    }

public:
    // heap element array's index is started from 1
    TopKMinHeap(size_t k)
    : size(0), capacity(k), data(capacity+1) {}


    void insert(const KeyType k, const ElemType e) {
        if (size < capacity) {
            data[++size] = make_pair(k, e);
            MaxHeapLiftUp(size);
            return;
        }

        if (size == capacity && k < data[1].first) {
            data[1] = make_pair(k, e);
            MaxHeapify(1);
        }
    }

    // the all data is clear after this operation
    vector<pair<KeyType, ElemType> > getTopK() {
        vector<pair<KeyType, ElemType> > tops;
        while (size >= 1) {
            tops.push_back(data[1]);
            if (size > 1) {
                swap(data[1], data[size]);
                --size;
                MaxHeapify(1);
            } else break;
        }
        return vector<pair<KeyType, ElemType> >(tops.rbegin(), tops.rend());
    }

private:
    void MaxHeapLiftUp(size_t pos) {
        size_t father = parent(pos);
        while (pos > 1 && data[pos].first > data[father].first) {
            swap(data[father], data[pos]);
            pos = father;
            father = parent(pos);
        }
    }

    void MaxHeapify(size_t pos) {
        size_t left = leftChild(pos), right = rightChild(pos), largest = pos;
        if (left <= size && data[left].first > data[largest].first)
            largest = left;
        if (right <= size && data[right].first > data[largest].first)
            largest = right;
        if (largest != pos) {
            swap(data[pos], data[largest]);
            MaxHeapify(largest);
        }
    }
};

#endif

然后是测试的代码:

// main.cpp
#include <stdio.h>
#include <ctime>
#include <vector>
#include <string>

#include "TopKMinHeap.h"

using namespace std;

void show(const vector<pair<int, string> >& v) {
    printf("\n---------\n");
    for (int i = 0; i < v.size(); ++i)
        printf("[%d] (%d, %s)\n", i, v[i].first, v[i].second.c_str());
}

int main() {
    vector<pair<int, string> > v;
    TopKMinHeap<int, string> tkh(4);;
    TopKMinHeap<int, string> aux(4);
    tkh.insert(1, string("first"));
    aux = tkh;
    v = aux.getTopK();
    show(v);

    tkh.insert(2, string("second"));
    aux = tkh;
    v = aux.getTopK();
    show(v);

    tkh.insert(3, string("third"));
    aux = tkh;
    v = aux.getTopK();
    show(v);

    tkh.insert(-1, string("fourth"));
    aux = tkh;
    v = aux.getTopK();
    show(v);

    tkh.insert(100, string("fivth"));
    aux = tkh;
    v = aux.getTopK();
    show(v);

    tkh.insert(0, string("sixth"));
    aux = tkh;
    v = aux.getTopK();
    show(v);

    return 0;
}

输出结果为:


---------
[0] (1, first)

---------
[0] (1, first)
[1] (2, second)

---------
[0] (1, first)
[1] (2, second)
[2] (3, third)

---------
[0] (-1, fourth)
[1] (1, first)
[2] (2, second)
[3] (3, third)

---------
[0] (-1, fourth)
[1] (1, first)
[2] (2, second)
[3] (3, third)

---------
[0] (-1, fourth)
[1] (0, sixth)
[2] (1, first)
[3] (2, second)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Jacketinsysu

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值