最近在复习王道考研数据结构,看到败者树这个视频觉得很新奇,以往实现k路归并都是用堆实现的,因此记录一下该算法的实现和实践。
王道考研数据结构-败者树
leetcode23.合并k个升序链表
概念
使用k路平衡归并策略,选出一个最小元素需要对比关键字
(
k
−
1
)
(k-1)
(k−1)次,导致内部归并所需时间增加
考虑用败者树优化,败者树是一棵完全二叉树,k个叶结点分别是当前参加比较的元素,非叶结点用来记忆左右子树中的“失败者”,胜者往上继续比较,一直到根结点。树中记录的是归并段的序号而不是具体数值,这样做的好处是不需要每一轮结束后补充新的值到叶结点,只需要不断重复pk过程,判断某个归并段空的情况即可。
初始化败者树时需要更新完全二叉树中编号小于等于
⌊
n
2
⌋
\left \lfloor \frac{n}{2}\right \rfloor
⌊2n⌋的结点,时间复杂度为
O
(
n
)
O(n)
O(n),n为总结点数,又因为败者树必定不存在度为1的结点,所以可得叶结点数k与结点总数的关系为:
k
=
n
+
1
2
k=\frac{n+1}{2}
k=2n+1,n必定为奇数,完全二叉树树高为
h
=
⌈
l
o
g
2
(
n
+
1
)
⌉
h=\left \lceil log_2(n+1) \right \rceil
h=⌈log2(n+1)⌉。
由此可知,我们设置完全二叉树的顺序存储空间大小时只需要开 2 k 2k 2k的大小,叶结点下标从 k k k开始到 2 k − 1 2k-1 2k−1结束,非叶节点下标从 1 1 1开始到 k − 1 k-1 k−1结束。
每轮pk的时间复杂度是 O ( l o g 2 n ) O(log_2n) O(log2n),pk的总轮数为所有归并段的数据总数,总时间复杂度为 O ( m l o g 2 k ) O(mlog_2k) O(mlog2k),k为归并段的数量,m为归并段的数据总数。
胜者树和败者树区别仅在于非叶结点记录的是败者编号还是胜者编号,记录胜者编号更便于pk
题意
合并K个升序链表
给你一个链表数组,每个链表都已经按升序排列。
请你将所有链表合并到一个升序链表中,返回合并后的链表。
样例
输入:lists = [[1,4,5],[1,3,4],[2,6]]
输出:[1,1,2,3,4,4,5,6]
解释:链表数组如下:
[
1->4->5,
1->3->4,
2->6
]
将它们合并到一个有序链表中得到。
1->1->2->3->4->4->5->6
输入:lists = []
输出:[]
输入:lists = [[]]
输出:[]
数据范围
提示:
- k = = l i s t s . l e n g t h k == lists.length k==lists.length
- 0 < = k < = 1 0 4 0 <= k <= 10^4 0<=k<=104
- 0 < = l i s t s [ i ] . l e n g t h < = 500 0 <= lists[i].length <= 500 0<=lists[i].length<=500
- − 1 0 4 < = l i s t s [ i ] [ j ] < = 1 0 4 -10^4<= lists[i][j]<=10^4 −104<=lists[i][j]<=104
- l i s t s [ i ] lists[i] lists[i]按升序排列
- l i s t s [ i ] . l e n g t h 的 总 和 不 超 过 1 0 4 lists[i].length的总和不超过10^4 lists[i].length的总和不超过104
class Solution {
public:
const int inf = 1e9;
int n;
ListNode* ans, *rear;
void pk(vector<int>& tr, vector<ListNode*>& lists, int i) {
// 记录胜者编号,值最小的胜出
int l = lists[tr[i << 1]] == nullptr ? inf : lists[tr[i << 1]]->val;
int r = lists[tr[i << 1 | 1]] == nullptr ? inf : lists[tr[i << 1 | 1]]->val;
if(l < r) tr[i] = tr[i << 1];
else tr[i] = tr[i << 1 | 1];
}
bool adjust(vector<int>& tr, vector<ListNode*>& lists) {
if(lists[tr[1]] == nullptr) return false;
int pos = tr[1];
ListNode* t = lists[pos];
lists[pos] = lists[pos]->next;
t->next = nullptr;
rear->next = t, rear = rear->next;
// 上一轮胜出者被移除,其所在归并段叶结点需要向上更新
for(int i = (n + pos) / 2;i >= 1;i >>= 1) {
pk(tr, lists, i);
}
return true;
}
ListNode* mergeKLists(vector<ListNode*>& lists) {
if(lists.size() == 0) return nullptr;
int cnt = 0;
n = lists.size();
int idx = 0;
vector<int> tr(2 * n);
ans = new ListNode(); //头指针
rear = ans; //尾指针
// 叶结点记录归并段的序号
for(int i = n;i <= 2 * n - 1;i++) tr[i] = idx++;
for(int i = n - 1;i >= 1;i--) pk(tr, lists, i);
while(adjust(tr, lists));
return ans->next;
}
};