相信许多人都发现并好奇,.Net Framework 为什么没有为我们提供优先队列这种数据结构的封装实现?其实不然,基于红黑树实现的SortedDictionary泛型类恰是优先队列的一种实现,而且其功能更为灵活。不过对于我而言,却总希望写一个利用二项堆实现的优先队列,尽管这两者从算法复杂度的层面上而言相差无几,都能在O(log2N)内提供添加和移除操作,但实际上,算法编码的复杂度也会或多或少地影响了数据结构的实际运作性能。在我想来,二项堆的结构逻辑远比红黑树简单得多,在实际需要优先队列的场景中,二项堆的实现应该能提供更好的性能。
既然要造轮子,那么我们就要造一个最好的轮子!其实在写这篇博客之前,我也搜索了许多相关的博客,许多博主UP出来的代码都过于简单,和教科书上的代码没啥两样。要知道,教科书上的代码只是为了阐述算法思想而存在,其完整性和严谨性都不足以成为工业级代码,这也是为什么明明我们数据结构学得不错,但是当我们去看微软的源码实现时却觉得一头雾水。为此,我专门写这篇博客,为读者们提供一个更加完整和严谨的二项堆实现。
以下的代码实现经过测试验证可用,从VS2017提供的性能测试数据来看,该实现方案的运行效率要比SortedDictionary高出不少,在10万数据量的增加和移除测试中,SortedDictionary耗时70ms,而我们的实现仅需35ms。
/*******************************************************************
* 版权所有: 深圳市震有科技有限公司
* 文件名称: BinaryHeap.cs
* 作 者: 李垄华
* 创建日期: 2018-03-12 10:14:42
* 文件版本: 1.0.0.0
* 修改时间: 修改人: 修改内容:
*******************************************************************/
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace LWLCX.Common.Collections.Generics
{
/// <summary>
/// 最大优先二叉堆
/// </summary>
/// <remarks>
/// <para>
/// 二叉堆是一种支持动态操作的集合数据结构,其内部维护用户数据的优先级顺序性。
/// 无论用户以何种顺序添加数据,都仅能取出(或移除)当前集合中优先级最高的元素。
/// </para>
/// <para>
/// <see cref="Push"/>方法可向集合中添加新的<typeparamref name="TValue"/>元素,
/// 同时指定该元素对应的<typeparamref name="TKey"/>键,集合内部依据<typeparamref name="TKey"/>键
/// 间的关系维护数据的顺序性。
/// </para>
/// <para>
/// <see cref="Pop"/>方法可从集合中取出并移除优先级最高的元素;而<see cref="Peek"/>方法
/// 则仅从集合中取出优先级最高的元素,但并不移除。
/// </para>
/// <para>
/// 需要注意的是,<see cref="Enumerator"/>的遍历顺序并不遵循上述定义,即遍历操作是无序的。
/// </para>
/// <para>
/// 此外,<see cref="BinaryHeap{TKey,TValue}"/>不提供线程安全访问保证。
/// </para>
/// </remarks>
/// <typeparam name="TKey">键类型</typeparam>
/// <typeparam name="TValue">值类型</typeparam>
[DebuggerDisplay("Count = {" + nameof(Count) + "}")]
public class BinaryHeap<TKey, TValue> : IReadOnlyCollection<KeyValuePair<TKey, TValue>>, ICollection
{
#region Segment
/// <summary>
/// 数据段
/// </summary>
/// <remarks>
/// <para>
/// 二叉堆内部采用链式数据段作为数据的存储容器。整个二叉堆的内部存储容器可以抽象地看成是一个连续的可变长数组,而
/// 该数组实际却是由多个Segment数据段链接而成。
/// 每个Segment中都有内置的数组<see cref="Segment.Data"/>作为容器,但该内置数组的长度却不是固定的。
/// 对于链中的任意一个Segment段而言,其内置数组的长度大概是该Segment段之前的所有Segment段的内置数组长度之和。
/// </para>
/// <para>
/// 我们把任意一个Segment中的任意一个元素在整个堆中的下标称为逻辑下标。出于方便,整个堆的逻辑下标从1开始计算。
/// 例如,对于某一个Segment段而言,<see cref="Segment.Data"/>[i]的逻辑下标应该是<see cref="Segment.Offset"/> + i。
/// </para>
/// <para>
/// 从逻辑上来看,二叉堆的内部元素关系是树状关系,由于它符合完全二叉树的特性,所以我们可以使用逻辑可变长数组进行存储,
/// 而这种逻辑可变长数组又是依靠链式Segment来实现的。对于任意一个数据段内部的元素结点而言,其逻辑左右子结点必然存在于当前
/// 数据段或紧接的下一个数据段内;其逻辑父结点必然存在于当前数据段或紧接的上一个数据段内。故而我们可以在O(1)的复杂度内
/// 找到某个元素的父结点或子结点。
/// </para>
/// </remarks>
private class Segment
{
#region Fields
// 前一个数据段
private readonly Segment m_previous;
// 数据段所属堆对象
private readonly BinaryHeap<TKey, TValue> m_source;
// 下标上界(此数据段中的所有元素在整个堆中的对应逻辑下标皆小于此上界)
// m_upBound = Offset + Capacity;
private readonly int m_upBound;
// 数组下标偏移
public readonly int Offset;
#endregion
#region Constructors
// 此构造用于为二叉堆扩展新的数据段.
private Segment(Segment previous)
{
m_previous = previous;
m_source = previous.m_source;
// 上一个数据段中最末元素的逻辑下标
Offset = previous.m_upBound;
Data = new KeyValuePair<TKey, TValue>[Math.Max(Offset, 16)];
m_upBound = Offset + Capacity;
previous.Next = this;
}
// 此构造用于创建出堆中的第一个数据段
public Segment(BinaryHeap<TKey, TValue> source)
{
m_source = source;
Data = new KeyValuePair<TKey, TValue>[16];
Offset = 1;
m_upBound = Offset + Capacity;
}
// 此构造用于创建出堆中的第一个数据段, 并附带初始化数据
public Segment(IEnumerable<KeyValuePair<TKey, TValue>> initData,
BinaryHeap<TKey, TValue> source)
{
m_previous = null;
m_source = source;
Offset = 1;
Data = initData.Concat(new[] {new KeyValuePair<TKey, TValue>()}).ToArray();
Length = Data.Length - 1;
m_upBound = Offset + Capacity;
source.m_head = source.m_tail = this;
if (Length > 1)
{
for (int i = BinaryHeapHelper.GetParentIndex(CurrentLastIndex); i > 0; --i)
{
MaxHeapify(this, i);
}
}
}
#endregion
#region Properties
// 下一个数据段
public Segment Next { get; private set; }
// 数据段容量
private int Capacity => Data.Length;
// 获取当前段已有数据长度
public int Length { get; private set; }
// 获取此数据段中最后一个元素在整个堆中的逻辑下标
private int CurrentLastIndex => Offset + Length - 1;
// 数据存储数组
public KeyValuePair<TKey, TValue>[] Data { get; }
// 数据段是否为空
private bool IsEmpty => Length == 0;
// 数据段是否已满
private bool IsFull => Length == Data.Length;
#endregion
#region Methods
// 尝试在此数据段中追加元素, 返回其下标
public void Push(KeyValuePair<TKey, TValue> item)
{
int localIndex = Length++;
// 本数据段满了
if (IsFull)
{
m_source.m_tail = new Segment(this);
}
var comparer = m_source.m_comparer;
var currentSeg = this;
var currentIndex = localIndex + Offset;
var parentSeg = currentSeg;
var parentIndex = BinaryHeapHelper.GetParentIndex(currentIndex);
while (parentIndex > 0)
{
// 由于父结点有可能存在于当前段或前一段内, 所以需要根据其逻辑下标判断
if (parentIndex < currentSeg.Offset)
{
parentSeg = currentSeg.m_previous;
}
// 如果新项的优先级高于父结点优先级
if (comparer.Compare(item.Key, parentSeg.Data[parentIndex - parentSeg.Offset].Key) > 0)
{
// 则父结点向下沉
currentSeg.Data[currentIndex - currentSeg.Offset] =
parentSeg.Data[parentIndex - parentSeg.Offset];
currentSeg = parentSeg;
currentIndex = parentIndex;
parentIndex = BinaryHeapHelper.GetParentIndex(currentIndex);
}
else
{
break;
}
}
currentSeg.Data[currentIndex - currentSeg.Offset] = item;
}
// 移除并返回数据段的末尾元素
public KeyValuePair<TKey, TValue> Pop()
{
//Contract.Assert(m_source.m_tail == this);
if (IsEmpty)
{
//Contract.Assert(m_previous != null);
m_previous.Next = null;
m_source.m_tail = m_previous;
return m_previous.Pop();
}
var head = m_source.m_head;
var maxItem = head.Data[0];
head.Data[0] = Data[--Length];
if (!(head == this && IsEmpty))
{
MaxHeapify(head, head.Offset);
}
return maxItem;
}
/// <summary>
/// 对指定元素执行堆下沉操作
/// </summary>
/// <param name="seg">指定元素所属数据段</param>
/// <param name="nodeIndex">指定元素在整个堆中的下标</param>
private static void MaxHeapify(Segment seg, int nodeIndex)
{
//Contract.Assert(nodeIndex <= seg.CurrentLastIndex && nodeIndex >= seg.Offset);
//Contract.Assert(seg != null);
BinaryHeapHelper.GetChildrenIndex(nodeIndex, out var leftIndex, out var rightIndex);
int maxIndex = nodeIndex; // 最优先结点下标(初始化为当前结点下标)
Segment maxSeg = seg;
var comparer = seg.m_source.m_comparer;
do
{
Segment leftSeg = seg, rightSeg = seg;
// 左子结点有可能在当前段, 也有可能在下一个段, 所以此处需要通过下标判断
if (leftIndex >= seg.m_upBound)
{
leftSeg = seg.Next;
}
// 如果存在左子结点
if (leftSeg != null && leftIndex <= leftSeg.CurrentLastIndex)
{
// 且其优先级比当前结点大
if (comparer.Compare(leftSeg.Data[leftIndex - leftSeg.Offset].Key,
maxSeg.Data[maxIndex - maxSeg.Offset].Key) > 0)
{
maxSeg = leftSeg;
maxIndex = leftIndex;
}
}
// 如果存在右子结点且其优先级比最优先结点大
if (rightIndex >= seg.m_upBound)
{
rightSeg = seg.Next;
}
if (rightSeg != null && rightIndex <= rightSeg.CurrentLastIndex)
{
if (comparer.Compare(rightSeg.Data[rightIndex - rightSeg.Offset].Key,
maxSeg.Data[maxIndex - maxSeg.Offset].Key) > 0)
{
maxSeg = rightSeg;
maxIndex = rightIndex;
}
}
// 如果最优先结点不是当前结点
if (maxIndex != nodeIndex)
{
BinaryHeapHelper.Swap(ref maxSeg.Data[maxIndex - maxSeg.Offset],
ref seg.Data[nodeIndex - seg.Offset]);
seg = maxSeg;
nodeIndex = maxIndex;
BinaryHeapHelper.GetChildrenIndex(nodeIndex, out leftIndex, out rightIndex);
}
else
{
break;
}
} while (true);
}
#endregion
}
#endregion
#region Enumerator
/// <summary>
/// 此枚举器用于遍历堆中的所有元素(非有序遍历)
/// </summary>
public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>
{
private readonly BinaryHeap<TKey, TValue> m_heap;
/*
* 在枚举器遍历的过程中, 我们不允许m_heap对象被外部更改,
* 所以我们需要通过m_version来进行版本号校对
*/
private readonly ulong m_version;
private Segment m_currentSeg;
private int m_currentIndex;
/// <summary>
/// 初始化枚举器
/// </summary>
/// <param name="heap">关联堆对象</param>
/// <remarks>算法复杂度: O(1)</remarks>
public Enumerator(BinaryHeap<TKey, TValue> heap)
{
m_heap = heap ?? throw new ArgumentNullException(nameof(heap));
m_heap.CheckAvailable();
// 记录初始版本号
m_version = m_heap.m_version;
m_currentSeg = m_heap.m_head;
m_currentIndex = -1;
}
public void Dispose()
{
}
/// <inheritdoc />
/// <remarks>算法复杂度: O(1)</remarks>
public bool MoveNext()
{
// 版本号校对
if (m_version != m_heap.m_version)
{
throw new InvalidOperationException("The heap has been changed.");
}
if (m_currentIndex == m_currentSeg.Length - 1)
{
if (m_currentSeg.Next == null)
{
return false;
}
m_currentSeg = m_currentSeg.Next;
m_currentIndex = 0;
return true;
}
++m_currentIndex;
return true;
}
/// <inheritdoc />
/// <remarks>算法复杂度: O(1)</remarks>
public void Reset()
{
// 版本号校对
if (m_version != m_heap.m_version)
{
throw new InvalidOperationException("The heap has been changed.");
}
m_currentSeg = m_heap.m_head;
m_currentIndex = -1;
}
public KeyValuePair<TKey, TValue> Current => m_currentSeg.Data[m_currentIndex];
object IEnumerator.Current => Current;
}
#endregion
#region Fields
// 头段
private Segment m_head;
// 尾段
private Segment m_tail;
// 优先级比较器
private readonly IComparer<TKey> m_comparer;
// 操作版本号
private ulong m_version;
#endregion
#region Properties
/// <summary>
/// 获取二叉堆中元素的数量
/// </summary>
public int Count => m_tail.Offset + m_tail.Length - 1;
/// <summary>
/// 检验二叉堆是否为空
/// </summary>
public bool IsEmpty => m_head.Length == 0;
/// <summary>
/// 指示二叉堆结构是否可用.
/// </summary>
/// <remarks>
/// 由于<see cref="Push"/>操作和<see cref="Pop"/>操作会改变堆内部的存储结构,
/// 期间依赖于<see cref="Comparer"/>方法的运行结果。如果在此过程中<see cref="Comparer"/>
/// 内部抛出异常,将导致整个算法流程中断,从而使得此数据结构损坏。
/// </remarks>
public bool Available { get; private set; } = true;
/// <summary>
/// 获取当前集合使用的优先级比较器
/// </summary>
public IComparer<TKey> Comparer => m_comparer;
object ICollection.SyncRoot => throw new NotSupportedException();
bool ICollection.IsSynchronized => false;
#endregion
#region Constructors
/// <summary>
/// 初始化一个空的二叉堆
/// </summary>
/// <remarks>算法复杂度: O(1)</remarks>
public BinaryHeap()
: this(Comparer<TKey>.Default)
{
}
/// <summary>
/// 使用指定数据集合<paramref name="data" />初始化一个二叉堆
/// </summary>
/// <param name="data">指定数据集合</param>
/// <remarks>算法复杂度: O(n)</remarks>
public BinaryHeap(IEnumerable<KeyValuePair<TKey, TValue>> data)
: this(Comparer<TKey>.Default, data)
{
}
/// <summary>
/// 初始化一个空的二叉堆, 并指定内部使用的比较器
/// </summary>
/// <param name="comparer">内部使用的比较器</param>
/// <remarks>算法复杂度: O(1)</remarks>
public BinaryHeap(IComparer<TKey> comparer)
{
m_comparer = comparer ?? Comparer<TKey>.Default;
m_head = m_tail = new Segment(this);
}
/// <summary>
/// 使用指定数据集合<paramref name="data" />初始化一个二叉堆, 并指定内部使用的比较器
/// </summary>
/// <param name="comparer">指定内部使用的比较器</param>
/// <param name="data">指定数据集合</param>
/// <remarks>算法复杂度: O(n)</remarks>
public BinaryHeap(IComparer<TKey> comparer, IEnumerable<KeyValuePair<TKey, TValue>> data)
{
if (data == null)
{
throw new ArgumentNullException(nameof(data));
}
m_comparer = comparer ?? Comparer<TKey>.Default;
m_head = m_tail = new Segment(data, this);
}
#endregion
#region Public Methods
/// <summary>
/// 添加新项到二叉堆
/// </summary>
/// <param name="key"></param>
/// <param name="value">被添加的元素</param>
/// <remarks>算法复杂度: O(logN)</remarks>
public void Push(TKey key, TValue value)
{
CheckAvailable();
++m_version;
try
{
m_tail.Push(new KeyValuePair<TKey, TValue>(key, value));
}
catch (Exception)
{
Available = false;
throw;
}
}
/// <summary>
/// 清空此二叉堆
/// </summary>
/// <remarks>算法复杂度: O(1)</remarks>
public void Clear()
{
++m_version;
m_tail = m_head = new Segment(this);
Available = true;
}
/// <summary>
/// 从二叉堆中移除并返回最高优先的元素
/// </summary>
/// <returns>最高优先的元素</returns>
/// <remarks>算法复杂度: O(logN)</remarks>
public KeyValuePair<TKey, TValue> Pop()
{
CheckAvailable();
if (IsEmpty)
{
throw new InvalidOperationException("Can't execute dequeue while the queue is empty.");
}
++m_version;
try
{
// m_head.Data[0] 是二叉堆中的最优先元素
return m_tail.Pop();
}
catch (Exception)
{
Available = false;
throw;
}
}
/// <summary>
/// 从二叉堆中提取但不移除最高优先的元素
/// </summary>
/// <returns>最高优先的元素</returns>
/// <remarks>算法复杂度: O(1)</remarks>
public KeyValuePair<TKey, TValue> Peek()
{
CheckAvailable();
if (IsEmpty)
{
throw new InvalidOperationException("Can't execute Peek() while the queue is empty.");
}
// m_head.Data[0] 是二叉堆中的最优先元素
return m_head.Data[0];
}
/// <summary>
/// 获取枚举器
/// </summary>
/// <returns></returns>
/// <remarks>算法复杂度: O(1)</remarks>
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
return new Enumerator(this);
}
/// <summary>
/// 合并两个二叉堆
/// </summary>
/// <param name="otherHeap"></param>
/// <param name="comparer"></param>
/// <returns>合并后的新二叉堆</returns>
/// <remarks>算法复杂度: O(n)</remarks>
public BinaryHeap<TKey, TValue> UnionWith(BinaryHeap<TKey, TValue> otherHeap, IComparer<TKey> comparer)
{
if (otherHeap == null)
{
throw new ArgumentNullException(nameof(otherHeap));
}
return new BinaryHeap<TKey, TValue>(comparer, this.Concat(otherHeap));
}
// 隐藏实现
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
/// <summary>
/// 把当前二叉堆中的元素复制到<paramref name="array" />. 但并不保证其优先级顺序性.
/// </summary>
/// <param name="array">指定的目标数组</param>
/// <param name="index"><paramref name="array" />内的可用起始下标</param>
/// <remarks>算法复杂度: O(n)</remarks>
public void CopyTo(Array array, int index)
{
CheckAvailable();
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
if (index < 0 || index >= array.Length)
{
throw new ArgumentOutOfRangeException(nameof(index));
}
if (array.Length - index < Count)
{
throw new ArgumentException("数组长度不足", nameof(array));
}
var seg = m_head;
while (seg != m_tail)
{
seg.Data.CopyTo(array, index);
index += seg.Length;
seg = seg.Next;
}
if (m_tail.Length > 0)
{
Array.Copy(m_tail.Data, 0, array, index, m_tail.Length);
}
}
#endregion
#region Private methods
/// <summary>
/// 检验此结构是否可用
/// </summary>
private void CheckAvailable()
{
if (!Available)
{
throw new InvalidOperationException("The inner structure has been broken.");
}
}
#endregion
}
/// <summary>
/// 帮助类
/// </summary>
internal static class BinaryHeapHelper
{
/// <summary>
/// 交换俩元素
/// </summary>
/// <param name="a"></param>
/// <param name="b"></param>
public static void Swap<T>(ref T a, ref T b)
{
var c = a;
a = b;
b = c;
}
/// <summary>
/// 获取指定下标为<paramref name="index" />的元素的父结点下标
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public static int GetParentIndex(int index)
{
return index >> 1;
}
/// <summary>
/// 获取指定下标为<paramref name="nodeIndex" />的元素的左右子结点下标
/// </summary>
/// <param name="nodeIndex"></param>
/// <param name="leftChildIndex"></param>
/// <param name="rightChildIndex"></param>
public static void GetChildrenIndex(int nodeIndex, out int leftChildIndex, out int rightChildIndex)
{
leftChildIndex = nodeIndex << 1;
rightChildIndex = leftChildIndex + 1;
}
}
}
如果你发现或测出了什么Bug,请在评论区指出。