在归并排序中,一般使用的是二路归并排序,而在二路归并的时候,每次比较,找到最小的数字都是O(1)的操作,而在归并K个路得时候,也就是K路归并,会出现一个问题,每次找到最小的数字是一个 O(k)
的操作。
那么能不能简化这个操作呢?
可以的,因为我们要归并的是K个链表,所以头结点的位置变化,并不会导致节点后面的信息丢失。
所以要先对K个头结点进行排序,然后把最小的删掉,再排序,那么这种情况最适合的就是堆排序了。
复杂度每次找到最小的,然后重新整理堆,复杂度是O(log n).
代码如下:{我写的这次的代码有些乱,将就着看把)
/**
* Definition for singly-linked list. */
struct ListNode {
int val;
ListNode *next;
ListNode(int x) : val(x), next(NULL) {}
};
class Slu {
public:
ListNode* mergeKLists(vector<ListNode*>& lists) {
int k = lists.size();
if(k == 0) return NULL;
vector<ListNode*> arr(k+1);
int j = 1;
for(int i = 1; i <= k; i++){
if(lists[i-1] == NULL) continue;
arr[j] = lists[i-1];
j++;
}
k = j-1;
buildHeap(arr, k);
ListNode * head = new ListNode(0);
ListNode * tempHead = head;
ListNode * temp;
while(k > 1){
temp = deleteHead(arr, k);
temp->next = tempHead->next;
tempHead->next = temp;
tempHead = temp;
}
if(k == 1 && arr[1] != NULL){
tempHead->next = arr[1];
}
return head->next;
}
ListNode* deleteHead(vector<ListNode*> &arr, int &k){
ListNode * res = arr[1];
arr[1] = arr[1]->next;
if(arr[1] == NULL){
int i = 1;
for(i = 1; i*2 < k; ){
if(arr[i*2]->val < arr[i*2+1]->val){
arr[i] = arr[i*2];
i *= 2;
}else{
arr[i] = arr[i*2+1];
i = i*2+1;
}
if(arr[k]->val < arr[i/2]->val){
arr[i/2] = arr[k];
i /= 2;
break;
}
}
arr[i] = arr[k];
k -= 1;
}else{
for(int i = 1; i*2 <= k;){
if(i*2 == k){
if(arr[i]->val > arr[i*2]->val){
swap(arr[i], arr[i*2]);
return res;
}else{
return res;
}
}
if(arr[i]->val <= arr[i*2]->val && arr[i]->val <= arr[i*2+1]->val){
return res;
}
if(arr[i*2]->val < arr[i]->val && arr[i*2]->val <= arr[i*2+1]->val){
swap(arr[i*2], arr[i]);
i*=2;
continue;
}
if(arr[i*2+1]->val < arr[i]->val && arr[i*2+1]->val <= arr[i*2]->val){
swap(arr[i*2+1], arr[i]);
i = i*2 + 1;
continue;
}
}
}
return res;
}
void buildHeap(vector<ListNode*> &arr, int &length){
for(int j = length/2; j > 0; j--){
for(int i = j; i <= length/2;){
if(i * 2 < length){
if(arr[i*2]->val <= arr[i]->val && arr[i*2]->val <= arr[i*2+1]->val){
ListNode * temp = arr[i*2];
arr[i*2] = arr[i];
arr[i] = temp;
i *= 2;
}else if(arr[i*2+1]->val <= arr[i]->val && arr[i*2+1]->val <= arr[i*2]->val){
ListNode * temp = arr[i*2+1];
arr[i*2+1] = arr[i];
arr[i] = temp;
i = (i*2+1);
}else{
break;
}
}else{
if(arr[i*2]->val < arr[i]->val){
ListNode * temp = arr[i*2];
arr[i*2] = arr[i];
arr[i] = temp;
i *= 2;
}else{
break;
}
}
}
}
}
};
ListNode * buildList(vector<int> vec){
ListNode * temp = new ListNode(vec[0]);
for (int i = vec.size()-1; i > 0; i--) {
ListNode * td = new ListNode(vec[i]);
td->next = temp->next;
temp->next = td;
}
return temp;
}
int main(int argc, const char * argv[]) {
vector<vector<int>> res = {{-8,-7,-6,-5,-3,-2,0},{-9,-5,1,2,2,4,4},{-3,-3,-3,-2,2},{-9,-6,-6,-6,-4,-3,2},{-8,-7,-3,-2,0,1,4},{-4,0},{-10,-2,-1,1,1},{-10}};
vector<ListNode *> vec;
for (int i = 0; i < res.size(); i++) {
vec.push_back(buildList(res[i]));
}
Slu slu;
ListNode* tp = slu.mergeKLists(vec);
while (tp != NULL) {
cout<<tp->val<<endl;
tp = tp->next;
}
return 0;
}