Java实现最大(小)堆以及堆排序、TopN问题
Java实现堆
什么是堆,先来了解原理,再看如何实现。
- 堆的定义:堆
(Heap)
是计算机科学中一类特殊的数据结构的统称。堆通常是一个可以被看做一棵完全二叉树的数组对象。
堆可以看成是一棵树,并且这棵树的子树也是堆。而最大堆就是父节点大于子节点,根节点是最大的节点;最小堆就是父亲节点小于子节点,根节点是最小的节点。
要注意堆和二叉查找树的区别,二叉查找树是左子节点比父节点小,右子节点比父节点大。
如图最大堆、最小堆:
而我们知道二叉树有一条性质是:
当前节点下标与子节点下标存在着这样的关系:左子节点的下标等于父节点的两倍,右子节点的下标等于父节点的两倍加1,即left_son.index = current.index * 2
,right_son.index = current.index * 2 + 1
。
而这个二叉树又是一个完全二叉树,结合上面这个特性,那么利用数组实现堆是十分合适的。
下面步入正题,堆的构建:
堆的构建
给定一个数组{1,9,7,5,18,4}
,对应的二叉树如下:
我们需要把它调整成最大堆,也就是堆的初始化过程:可以从上往下,也可以从下往上,但是从上往下每次都要将子树遍历完,重复遍历了很多节点。所以采用自底向上方法,即从最后一个有子节点的开始(因为它是最后一个节点的父节点,所以它的下标是size / 2
),让它的两个子节点进行比较,找到大的值再和父节点的值进行比较。如果父结点的值小,则子结点和父结点进行交换,交换之后再递归往下比较。然后再一步步递归上去,直到根节点结束。
7
是最后一个有子节点的节点,和左右儿子节点比较,然后和大的换位子,然后先往下递归重复这个步骤到比它小的节点,然后再向上递归重复这个步骤到根节点,这里7是最大的,不用换,然后递归到上一层,即1
和7
比较,具体过程如下:
实现代码:
/**
* 构造函数
*/
public MaxHeap(int[] num, int capacity) {
this(capacity + 1);
if (num.length > capacity) {
try {
throw new Exception("capacity不能小于数组的长度!");
} catch (Exception e) {
e.printStackTrace();
}
}
for (int i = 1, len = num.length; i <= len; ++i) {
++size;
heap[i] = num[i - 1];
}
adjustment(); // 调整堆
}
/**
* 调整堆
*/
private void adjustment() {
for (int i = size / 2; i >= 1; --i) {
int temp = heap[i]; // 当前操作的节点,从后往前,第一个操作的就是那个有子节点的节点
int max_index = i * 2; // 当前节点的左儿子节点
while (max_index <= size) { // 当前节点的左儿子节点有效,下标在1~size代表有效
if (max_index < size && heap[max_index] < heap[max_index + 1])
++max_index;
if (temp > heap[max_index]) break;
heap[max_index / 2] = heap[max_index];
max_index *= 2;
}
heap[max_index / 2] = temp;
}
}
堆的插入
最大堆的插入: 将元素插入末尾,因为原本已经是最大堆了,所以只需层层向上比较即可,比如上图,我要插入10
,只需要和7比较,然后和18比较即可。
实现代码:
public void insert(int element) {
if (size == 0) { // 堆为空,直接插入0位置
heap[1] = element;
++size;
return;
} else if (size >= heap.length - 1) { // 容量已满,扩容
resize();
}
int index = ++size;
while (index != 1 && element > heap[index / 2]) {
heap[index] = heap[index / 2];
index /= 2;
}
heap[index] = element;
}
public void resize() {
heap = Arrays.copyOf(heap, heap.length + heap.length / 2);
}
堆的删除
删除根节点: 这就更简单了,直接将最后一个节点移到根节点,然后再调整堆,然后递归向下比较,因为根节点到叶子结点的某条路一定是有序的,调整即找个合适的位置将根节点放到那里。
实现代码:
public int removeMax() {
int max = heap[1];
heap[1] = heap[size--]; // 可以不写这句,但size要--
// adjustment();
int temp = heap[1];
int current_index = 1 * 2;
while (current_index <= size) {
if (current_index < size && heap[current_index] < heap[current_index + 1]) {
current_index++;
}
if (heap[current_index] < temp) break;
heap[current_index / 2] = heap[current_index];
current_index *= 2;
}
heap[current_index / 2] = temp;
return max;
}
具体实现代码
详细实现代码(这个只适合存放int
类型,往下面有另一种,带泛型的):
import java.util.Arrays;
public class MaxHeap {
private int[] heap;
private int size;
private final static int DEFAULT_CAPACITY = 10;
public MaxHeap() {
this(DEFAULT_CAPACITY);
}
public MaxHeap(int capacity) {
heap = new int[capacity + 1];
size = 0;
}
public MaxHeap(int[] num) {
this(num, num.length * 2);
}
public MaxHeap(int[] num, int capacity) {
this(capacity + 1);
if (num.length > capacity) {
try {
throw new Exception("capacity不能小于数组的长度!");
} catch (Exception e) {
e.printStackTrace();
}
}
for (int i = 1, len = num.length; i <= len; ++i) {
++size;
heap[i] = num[i - 1];
}
adjustment(); // 调整堆
}
/**
* 调整堆
*/
private void adjustment() {
for (int i = size / 2; i >= 1; --i) {
int temp = heap[i];
int max_index = i * 2;
while (max_index <= size) {
if (max_index < size && heap[max_index] < heap[max_index + 1])
++max_index;
if (temp > heap[max_index]) break;
heap[max_index / 2] = heap[max_index];
max_index *= 2;
}
heap[max_index / 2] = temp;
}
}
public void insert(int element) {
if (size == 0) { // 堆为空,直接插入0位置
heap[1] = element;
++size;
return;
} else if (size >= heap.length - 1) { // 容量已满,扩容
resize();
}
int index = ++size;
while (index != 1 && element > heap[index / 2]) {
heap[index] = heap[index / 2];
index /= 2;
}
heap[index] = element;
}
public void resize() {
heap = Arrays.copyOf(heap, heap.length + heap.length / 2);
}
/**
* 删除最大值
*
* @return
*/
public int removeMax() {
int max = heap[1];
heap[1] = heap[size--];
// adjustment();
int temp = heap[1];
int current_index = 1 * 2;
while (current_index <= size) {
if (current_index < size && heap[current_index] < heap[current_index + 1]) {
current_index++;
}
if (heap[current_index] < temp) break;
heap[current_index / 2] = heap[current_index];
current_index *= 2;
}
heap[current_index / 2] = temp;
return max;
}
public void print() {
for (int i = 1; i <= size; ++i) {
System.out.print(heap[i] + " ");
}
}
public static void main(String[] args) {
MaxHeap maxHeap = new MaxHeap();
maxHeap.insert(12);
maxHeap.insert(14);
maxHeap.insert(11);
maxHeap.insert(110);
maxHeap.insert(411);
maxHeap.insert(11);
maxHeap.print();
System.out.println();
System.out.println("============");
MaxHeap maxHeap1 = new MaxHeap(new int[]{12, 14, 11, 110, 411, 11});
maxHeap1.removeMax();
maxHeap1.print();
}
}
带泛型的实现(更通用):
package heap;
import java.lang.reflect.Array;
import java.util.Arrays;
public class MaxHeap_T<T extends Comparable<T>> {
private T[] heap;
private int size;
private final static int DEFAULT_CAPACITY = 10;
public MaxHeap_T(Class<T> clazz) {
this(clazz, DEFAULT_CAPACITY);
}
public MaxHeap_T(Class<T> clazz, int capacity) {
heap = (T[]) Array.newInstance(clazz, capacity);
size = 0;
}
public MaxHeap_T(Class<T> clazz, T[] num) {
this(clazz, num, num.length * 2);
}
public MaxHeap_T(Class<T> clazz, T[] num, int capacity) {
this(clazz, capacity + 1);
if (num.length > capacity) {
try {
throw new Exception("capacity不能小于数组的长度!");
} catch (Exception e) {
e.printStackTrace();
}
}
for (int i = 1, len = num.length; i <= len; ++i) {
++size;
heap[i] = num[i - 1];
}
adjustment(); // 调整堆
}
/**
* 调整堆
*/
private void adjustment() {
for (int i = size / 2; i >= 1; --i) {
T temp = heap[i];
int max_index = i * 2;
while (max_index <= size) {
if (max_index < size && heap[max_index].compareTo(heap[max_index + 1]) < 0)
++max_index;
if (temp.compareTo(heap[max_index]) > 0) break;
heap[max_index / 2] = heap[max_index];
max_index *= 2;
}
heap[max_index / 2] = temp;
}
}
public void insert(T element) {
if (size == 0) { // 堆为空,直接插入0位置
heap[1] = element;
++size;
return;
} else if (size >= heap.length - 1) { // 容量已满,扩容
resize();
}
int index = ++size;
while (index != 1 && element.compareTo(heap[index / 2]) > 0) {
heap[index] = heap[index / 2];
index /= 2;
}
heap[index] = element;
}
public void resize() {
heap = Arrays.copyOf(heap, heap.length + heap.length / 2);
}
/**
* 删除最大值
*
* @return
*/
public T removeMax() {
T max = heap[1];
heap[1] = heap[size--];
// adjustment();
T temp = heap[1];
int current_index = 1 * 2;
while (current_index <= size) {
if (current_index < size && heap[current_index].compareTo(heap[current_index + 1]) < 0) {
current_index++;
}
if (heap[current_index].compareTo(temp) < 0) break;
heap[current_index / 2] = heap[current_index];
current_index *= 2;
}
heap[current_index / 2] = temp;
return max;
}
public void print() {
for (int i = 1; i <= size; ++i) {
System.out.print(heap[i] + " ");
}
}
public static void main(String[] args) {
MaxHeap maxHeap = new MaxHeap();
maxHeap.insert(12);
maxHeap.insert(14);
maxHeap.insert(11);
maxHeap.insert(110);
maxHeap.insert(411);
maxHeap.insert(11);
maxHeap.print();
System.out.println();
System.out.println("============");
MaxHeap_T<Integer> maxHeap1 = new MaxHeap_T<>(Integer.class, new Integer[]{12, 14, 11, 110, 411, 11});
maxHeap1.removeMax();
maxHeap1.print();
}
}
堆排序
了解了堆的原理,堆排序应该很简单了,我们只需要将最后一个节点与第一个节点交换即可,换完一个size
减1
,然后调整堆,调整的时候和删除是一样的,因为每条路径都是有序的,然后循环这个步骤,直到换完。
实现代码(利用上面的代码):
public void heapSort() {
int temp = size; // 先把原大小存一下
for (int i = size; i > 0; --i) {
--size;
int t = heap[i];
heap[i] = heap[1];
heap[1] = t;
int temp = heap[1];
int current_index = 1 * 2;
while (current_index <= size) {
if (current_index < size && heap[current_index] < heap[current_index + 1]) {
current_index++;
}
if (heap[current_index] < temp) break;
heap[current_index / 2] = heap[current_index];
current_index *= 2;
}
heap[current_index / 2] = temp;
}
size = temp;
}
TopN问题
这个就更加简单了,维护一个N
就行了,然后删除n - 1
次根节点,或者将根节点和一个结点交换,然后size - 1
然后调整堆。
public int topN(int n) {
int res = 0;
for (int i = size; i > 0; --i) {
if (--n == 0) {
res = heap[1];
break;
}
removeMax();
}
return res;
}
leetcode第347题:前 K 个高频元素
给定一个非空的整数数组,返回其中出现频率前 k 高的元素。
示例 1:
输入: nums = [1,1,1,2,2,3], k = 2
输出: [1,2]
示例 2:
输入: nums = [1], k = 1
输出: [1]
说明:
你可以假设给定的 k 总是合理的,且 1 ≤ k ≤ 数组中不相同的元素的个数。
你的算法的时间复杂度必须优于 O(n log n) , n 是数组的大小。
思路,先用一个map
将他们相同的累加,然后遍历map
插入堆,然后就是TopN问题了。
实现代码(带上上面的有泛型的堆实现代码,也可以用Java
自带的优先队列PriorityQueue
):
import java.lang.reflect.Array;
import java.util.*;
public class Number_347 {
public List<Integer> topKFrequent(int[] nums, int k) {
HashMap<Integer, Integer> map = new HashMap<>();
for (int num : nums) {
map.put(num, map.getOrDefault(num, 0) + 1);
}
MaxHeap_347<Node> heap = new MaxHeap_347<>(Node.class);
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
heap.insert(new Node(entry.getKey(), entry.getValue()));
}
List<Integer> res = new ArrayList<>();
for (int i = 0; i < k; ++i) {
res.add(heap.removeMax().key);
}
return res;
}
static class Node implements Comparable<Node> {
int key;
int value;
public Node(int key, int value) {
this.key = key;
this.value = value;
}
@Override
public int compareTo(Node o) {
return this.value - o.value;
}
}
public static void main(String[] args) {
Number_347 n = new Number_347();
n.topKFrequent(new int[]{3, 3, 1, 1, 2}, 3).forEach(System.out::println);
}
}