利用快速排序partition:
#include <iostream>
#include <map>
#include <algorithm>
#include <limits.h>
#include <assert.h>
using namespace std;
/// @brief Returns the k-th smallest value (1-based) of num[left..right], bounds inclusive.
///
/// Quickselect: the current range is partitioned around its middle element with
/// two cursors moving towards each other, then the search continues in whichever
/// side contains the k-th smallest value. The array is reordered in place.
///
/// @param num    Array to search; its elements inside [left, right] are permuted.
/// @param k      1-based rank of the wanted value, 1 <= k <= right - left + 1.
/// @param left   Index of the first element of the range.
/// @param right  Index of the last element of the range.
/// @return The k-th smallest value of the range.
///
/// The precondition on k is checked with assert, so it is only enforced when
/// assertions are enabled.
int selectK(int num[], int k, int left, int right) {
    // Each pass narrows [left, right]; looping instead of recursing keeps the
    // stack depth constant.
    while (true) {
        assert(k <= (right - left + 1) && k >= 1);

        // Written as left + half-width so the midpoint cannot overflow.
        const int pivot = num[left + (right - left) / 2];
        int i = left;
        int j = right;

        // "<=" (not "<"): the self-swap when i == j still advances both cursors,
        // so even a one-element range makes progress and reaches the return below.
        while (i <= j) {
            // Strict comparisons make both cursors stop on values equal to the
            // pivot, which keeps them from running off the ends of the range.
            while (num[i] < pivot) {
                ++i;
            }
            while (num[j] > pivot) {
                --j;
            }
            if (i <= j) {
                std::swap(num[i], num[j]);
                ++i;
                --j;
            }
        }

        // Now num[left..i-1] <= pivot <= num[j+1..right] and i > j.
        const int lowCount = i - left;

        // The cursors crossed with exactly one slot between them: that slot holds
        // the pivot value and is the largest of the lowCount smallest elements.
        if (k == lowCount && i - 1 == j + 1) {
            return pivot;
        }

        if (k <= lowCount) {
            right = i - 1;   // answer lies in the low part
        } else {
            k -= lowCount;   // skip the lowCount smaller-or-equal elements
            left = i;
        }
    }
}
/// Self-check driver for selectK: asks for every rank of a small array and
/// verifies each answer instead of discarding it.
int main() {
    int num[] = { 3, 2, 1, 4, 5, 6 };
    const int count = static_cast<int>(sizeof(num) / sizeof(num[0]));

    // num holds the values 1..count, so the k-th smallest value must be k itself.
    // selectK reorders num between calls, which does not change the expected answers.
    for (int k = 1; k <= count; ++k) {
        const int kth = selectK(num, k, 0, count - 1);
        assert(kth == k);
        (void)kth;  // keeps the build warning-free when assertions are compiled out
    }

    // Asking for a rank larger than the range, e.g. selectK(num, 6, 0, 4),
    // violates selectK's precondition and trips its assert when assertions are enabled.
    return 0;
}
Invariant after the partition loop: num[left..i-1] <= pivot <= num[j+1..right]
Python version:
from itertools import permutations
# 1. The partition must establish num[left..i-1] <= pivot <= num[j+1..right]. The "<=" in this invariant cannot be tightened to "<", because values equal to the pivot may end up on either side.
def selectKth(num, k, left, right):
    """Return the k-th smallest value (1-based) of num[left..right], bounds inclusive.

    Quickselect: the current range is partitioned around its middle element
    with two cursors moving towards each other, then the search continues in
    whichever side contains the k-th smallest value.

    Args:
        num: mutable sequence of mutually comparable values; the elements
            inside [left, right] are reordered in place.
        k: 1-based rank of the wanted value, 1 <= k <= right - left + 1.
        left: index of the first element of the range.
        right: index of the last element of the range.

    Returns:
        The k-th smallest value of the range.

    Raises:
        AssertionError: if k is outside 1..(right - left + 1) (only when
            assertions are enabled).
    """
    # Each pass narrows [left, right]; looping instead of recursing means the
    # interpreter's recursion limit is never a concern.
    while True:
        assert (k <= right - left + 1 and k >= 1)
        pivot = num[(left + right) // 2]
        i, j = left, right

        # "<=" (not "<"): the self-swap when i == j still advances both
        # cursors, so even a one-element range makes progress and reaches the
        # return below.
        while i <= j:
            # Strict comparisons make both cursors stop on values equal to the
            # pivot, which keeps them from running off the ends of the range.
            while num[i] < pivot:
                i += 1
            while num[j] > pivot:
                j -= 1
            if i <= j:
                num[i], num[j] = num[j], num[i]
                i, j = i + 1, j - 1

        # Now num[left..i-1] <= pivot <= num[j+1..right] and i > j.
        low_count = i - left

        # The cursors crossed with exactly one slot between them: that slot
        # holds the pivot value and is the largest of the low_count smallest
        # elements, so no extra comparison against num[i-1] is needed.
        if k == low_count and i - 1 == j + 1:
            return pivot

        if k <= low_count:
            right = i - 1  # answer lies in the low part
        else:
            # Skip the low_count smaller-or-equal elements.
            k, left = k - low_count, i
if __name__ == '__main__':
    # Exhaustive check: for every ordering of 1..6 and every rank k, the k-th
    # smallest value reported by selectKth must be k itself.
    for ordering in permutations([1, 2, 3, 4, 5, 6], 6):
        for k in range(1, 7):
            # selectKth reorders its argument, so it gets its own copy and the
            # untouched copy is the one that is printed.
            working = list(ordering)
            untouched = list(ordering)
            kthNum = selectKth(working, k, 0, 5)
            print("perm={0} k={1} kth={2}".format(untouched, k, kthNum))
            assert (k == kthNum)
BTW, there is an O(n) method for selecting the first k numbers. Think it over!
from itertools import permutations
# The partition must establish num[left..i-1] <= pivot <= num[j+1..right]. The "<=" in this invariant cannot be tightened to "<", because values equal to the pivot may end up on either side.
def selectTopK(num, k, left, right):
    """Reorder num[left..right] so its k smallest values come first.

    Same partition-and-narrow scheme as a quickselect, but the index of the
    k-th smallest value is returned instead of the value. On return, with
    `left` being the value passed by the caller, num[left..index] holds the k
    smallest values of the range (in unspecified order) and
    index == left + k - 1.

    Args:
        num: mutable sequence of mutually comparable values; the elements
            inside [left, right] are reordered in place.
        k: how many of the smallest values to gather,
            1 <= k <= right - left + 1.
        left: index of the first element of the range.
        right: index of the last element of the range.

    Returns:
        Index of the k-th smallest value after the reordering.

    Raises:
        AssertionError: if k is outside 1..(right - left + 1) (only when
            assertions are enabled).
    """
    # Each pass narrows [left, right]; looping instead of recursing means the
    # interpreter's recursion limit is never a concern.
    while True:
        assert (k <= right - left + 1 and k >= 1)
        pivot = num[(left + right) // 2]
        i, j = left, right

        # "<=" (not "<"): the self-swap when i == j still advances both
        # cursors, so even a one-element range makes progress.
        while i <= j:
            # Strict comparisons make both cursors stop on values equal to the
            # pivot, which keeps them from running off the ends of the range.
            while num[i] < pivot:
                i += 1
            while num[j] > pivot:
                j -= 1
            if i <= j:
                num[i], num[j] = num[j], num[i]
                i, j = i + 1, j - 1

        # Now num[left..i-1] <= pivot <= num[j+1..right] and i > j.
        low_count = i - left

        # The cursors crossed with exactly one slot between them: that slot
        # holds the pivot value and closes the block of the k smallest values.
        if k == low_count and i - 1 == j + 1:
            return i - 1

        if k <= low_count:
            right = i - 1  # boundary lies in the low part
        else:
            # The low part is entirely inside the top-k block; continue with
            # the remaining count in the high part.
            k, left = k - low_count, i
if __name__ == '__main__':
    # Exhaustive check on a list with duplicates: for every ordering and every
    # k, the prefix gathered by selectTopK must match the k smallest values.
    raw = [1, 2, 3, 1, 2, 3]
    rawlen = len(raw)
    # Every permutation sorts to the same list, so compute the reference once.
    sortedRaw = sorted(raw)
    for perm in permutations(raw, rawlen):
        for k in range(1, rawlen + 1):
            # selectTopK reorders its argument, so keep an untouched copy for
            # the report and hand a separate working copy to the function.
            beforeSelect = list(perm)
            working = list(perm)
            kthIndex = selectTopK(working, k, 0, rawlen - 1)
            selectedTopK = sorted(working[:kthIndex + 1])
            topK = sortedRaw[:kthIndex + 1]
            print("perm={0} selected={1} topK={2}".format(beforeSelect, selectedTopK, topK))
            assert (topK == selectedTopK)