Java实现最大(小)堆以及堆排序、TopN问题

7 篇文章 5 订阅

Java实现最大(小)堆以及堆排序、TopN问题


Java实现堆

什么是堆,先来了解原理,再看如何实现。

  • 堆的定义:堆(Heap)是计算机科学中一类特殊的数据结构的统称。堆通常是一个可以被看做一棵完全二叉树的数组对象。

堆可以看成是一棵树,并且这棵树的子树也是堆。而最大堆就是父节点大于子节点,根节点是最大的节点;最小堆就是父亲节点小于子节点,根节点是最小的节点。

要注意堆和二叉查找树的区别,二叉查找树是左子节点比父节点小,右子节点比父节点大。

如图最大堆、最小堆:

最大堆最小堆

而我们知道二叉树有一条性质是:
当前节点下标与子节点下标存在着这样的关系:左子节点的下标等于父节点的两倍,右子节点的下标等于父节点的两倍加1,即left_son.index = current.index * 2right_son.index = current.index * 2 + 1

而这个二叉树又是一个完全二叉树,结合上面这个特性,那么利用数组实现堆是十分合适的。

下面步入正题,堆的构建:

堆的构建

给定一个数组{1,9,7,5,18,4},对应的二叉树如下:

二叉树

我们需要把它调整成最大堆,也就是堆的初始化过程:可以从上往下,也可以从下往上,但是从上往下每次都要将子树遍历完,重复遍历了很多节点。所以采用自底向上方法,即从最后一个有子节点的开始(因为它是最后一个节点的父节点,所以它的下标是size / 2),让它的两个子节点进行比较,找到大的值再和父节点的值进行比较。如果父结点的值小,则子结点和父结点进行交换,交换之后再递归往下比较。然后再一步步递归上去,直到根节点结束。

7是最后一个有子节点的节点,和左右儿子节点比较,然后和大的换位子,然后先往下递归重复这个步骤到比它小的节点,然后再向上递归重复这个步骤到根节点,这里7是最大的,不用换,然后递归到上一层,即17比较,具体过程如下:

构建过程

实现代码:

    /**
     * 构造函数
     */
    public MaxHeap(int[] num, int capacity) {
        this(capacity + 1);
        if (num.length > capacity) {
            try {
                throw new Exception("capacity不能小于数组的长度!");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        for (int i = 1, len = num.length; i <= len; ++i) {
            ++size;
            heap[i] = num[i - 1];
        }
        adjustment(); // 调整堆
    }

    /**
     * 调整堆
     */
    private void adjustment() {
        for (int i = size / 2; i >= 1; --i) {
            int temp = heap[i]; // 当前操作的节点,从后往前,第一个操作的就是那个有子节点的节点
            int max_index = i * 2; // 当前节点的左儿子节点
            while (max_index <= size) { // 当前节点的左儿子节点有效,下标在1~size代表有效
                if (max_index < size && heap[max_index] < heap[max_index + 1])
                    ++max_index;
                if (temp > heap[max_index]) break;
                heap[max_index / 2] = heap[max_index];
                max_index *= 2;
            }
            heap[max_index / 2] = temp;
        }
    }

堆的插入

最大堆的插入: 将元素插入末尾,因为原本已经是最大堆了,所以只需层层向上比较即可,比如上图,我要插入10,只需要和7比较,然后和18比较即可。

插入10

实现代码:

    public void insert(int element) {
        if (size == 0) { // 堆为空,直接插入0位置
            heap[1] = element;
            ++size;
            return;
        } else if (size >= heap.length - 1) { // 容量已满,扩容
            resize();
        }
        int index = ++size;
        while (index != 1 && element > heap[index / 2]) {
            heap[index] = heap[index / 2];
            index /= 2;
        }
        heap[index] = element;
    }

    public void resize() {
        heap = Arrays.copyOf(heap, heap.length + heap.length / 2);
    }

堆的删除

删除根节点: 这就更简单了,直接将最后一个节点移到根节点,然后再调整堆,然后递归向下比较,因为根节点到叶子结点的某条路一定是有序的,调整即找个合适的位置将根节点放到那里。

实现代码:

    public int removeMax() {
        int max = heap[1];
        heap[1] = heap[size--]; // 可以不写这句,但size要--
//        adjustment();
        int temp = heap[1];
        int current_index = 1 * 2;
        while (current_index <= size) {
            if (current_index < size && heap[current_index] < heap[current_index + 1]) {
                current_index++;
            }
            if (heap[current_index] < temp) break;
            heap[current_index / 2] = heap[current_index];
            current_index *= 2;
        }
        heap[current_index / 2] = temp;
        return max;
    }

具体实现代码

详细实现代码(这个只适合存放int类型,往下面有另一种,带泛型的):

import java.util.Arrays;

public class MaxHeap {
    private int[] heap;
    private int size;
    private final static int DEFAULT_CAPACITY = 10;

    public MaxHeap() {
        this(DEFAULT_CAPACITY);
    }

    public MaxHeap(int capacity) {
        heap = new int[capacity + 1];
        size = 0;
    }

    public MaxHeap(int[] num) {
        this(num, num.length * 2);
    }

    public MaxHeap(int[] num, int capacity) {
        this(capacity + 1);
        if (num.length > capacity) {
            try {
                throw new Exception("capacity不能小于数组的长度!");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        for (int i = 1, len = num.length; i <= len; ++i) {
            ++size;
            heap[i] = num[i - 1];
        }
        adjustment(); // 调整堆
    }

    /**
     * 调整堆
     */
    private void adjustment() {
        for (int i = size / 2; i >= 1; --i) {
            int temp = heap[i];
            int max_index = i * 2;
            while (max_index <= size) {
                if (max_index < size && heap[max_index] < heap[max_index + 1])
                    ++max_index;
                if (temp > heap[max_index]) break;
                heap[max_index / 2] = heap[max_index];
                max_index *= 2;
            }
            heap[max_index / 2] = temp;
        }
    }

    public void insert(int element) {
        if (size == 0) { // 堆为空,直接插入0位置
            heap[1] = element;
            ++size;
            return;
        } else if (size >= heap.length - 1) { // 容量已满,扩容
            resize();
        }
        int index = ++size;
        while (index != 1 && element > heap[index / 2]) {
            heap[index] = heap[index / 2];
            index /= 2;
        }
        heap[index] = element;
    }

    public void resize() {
        heap = Arrays.copyOf(heap, heap.length + heap.length / 2);
    }

    /**
     * 删除最大值
     *
     * @return
     */
    public int removeMax() {
        int max = heap[1];
        heap[1] = heap[size--];
//        adjustment();
        int temp = heap[1];
        int current_index = 1 * 2;
        while (current_index <= size) {
            if (current_index < size && heap[current_index] < heap[current_index + 1]) {
                current_index++;
            }
            if (heap[current_index] < temp) break;
            heap[current_index / 2] = heap[current_index];
            current_index *= 2;
        }
        heap[current_index / 2] = temp;
        return max;
    }

    public void print() {
        for (int i = 1; i <= size; ++i) {
            System.out.print(heap[i] + " ");
        }
    }

    public static void main(String[] args) {
        MaxHeap maxHeap = new MaxHeap();
        maxHeap.insert(12);
        maxHeap.insert(14);
        maxHeap.insert(11);
        maxHeap.insert(110);
        maxHeap.insert(411);
        maxHeap.insert(11);
        maxHeap.print();
        System.out.println();
        System.out.println("============");
        MaxHeap maxHeap1 = new MaxHeap(new int[]{12, 14, 11, 110, 411, 11});
        maxHeap1.removeMax();
        maxHeap1.print();
    }
}

带泛型的实现(更通用):

package heap;

import java.lang.reflect.Array;
import java.util.Arrays;

public class MaxHeap_T<T extends Comparable<T>> {
    private T[] heap;
    private int size;
    private final static int DEFAULT_CAPACITY = 10;

    public MaxHeap_T(Class<T> clazz) {
        this(clazz, DEFAULT_CAPACITY);
    }

    public MaxHeap_T(Class<T> clazz, int capacity) {
        heap = (T[]) Array.newInstance(clazz, capacity);
        size = 0;
    }

    public MaxHeap_T(Class<T> clazz, T[] num) {
        this(clazz, num, num.length * 2);
    }

    public MaxHeap_T(Class<T> clazz, T[] num, int capacity) {
        this(clazz, capacity + 1);
        if (num.length > capacity) {
            try {
                throw new Exception("capacity不能小于数组的长度!");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        for (int i = 1, len = num.length; i <= len; ++i) {
            ++size;
            heap[i] = num[i - 1];
        }
        adjustment(); // 调整堆
    }

    /**
     * 调整堆
     */
    private void adjustment() {
        for (int i = size / 2; i >= 1; --i) {
            T temp = heap[i];
            int max_index = i * 2;
            while (max_index <= size) {
                if (max_index < size && heap[max_index].compareTo(heap[max_index + 1]) < 0)
                    ++max_index;
                if (temp.compareTo(heap[max_index]) > 0) break;
                heap[max_index / 2] = heap[max_index];
                max_index *= 2;
            }
            heap[max_index / 2] = temp;
        }
    }

    public void insert(T element) {
        if (size == 0) { // 堆为空,直接插入0位置
            heap[1] = element;
            ++size;
            return;
        } else if (size >= heap.length - 1) { // 容量已满,扩容
            resize();
        }
        int index = ++size;
        while (index != 1 && element.compareTo(heap[index / 2]) > 0) {
            heap[index] = heap[index / 2];
            index /= 2;
        }
        heap[index] = element;
    }

    public void resize() {
        heap = Arrays.copyOf(heap, heap.length + heap.length / 2);
    }

    /**
     * 删除最大值
     *
     * @return
     */
    public T removeMax() {
        T max = heap[1];
        heap[1] = heap[size--];
        //        adjustment();
        T temp = heap[1];
        int current_index = 1 * 2;
        while (current_index <= size) {
            if (current_index < size && heap[current_index].compareTo(heap[current_index + 1]) < 0) {
                current_index++;
            }
            if (heap[current_index].compareTo(temp) < 0) break;
            heap[current_index / 2] = heap[current_index];
            current_index *= 2;
        }
        heap[current_index / 2] = temp;
        return max;
    }

    public void print() {
        for (int i = 1; i <= size; ++i) {
            System.out.print(heap[i] + " ");
        }
    }

    public static void main(String[] args) {
        MaxHeap maxHeap = new MaxHeap();
        maxHeap.insert(12);
        maxHeap.insert(14);
        maxHeap.insert(11);
        maxHeap.insert(110);
        maxHeap.insert(411);
        maxHeap.insert(11);
        maxHeap.print();
        System.out.println();
        System.out.println("============");
        MaxHeap_T<Integer> maxHeap1 = new MaxHeap_T<>(Integer.class, new Integer[]{12, 14, 11, 110, 411, 11});
        maxHeap1.removeMax();
        maxHeap1.print();
    }
}

堆排序

了解了堆的原理,堆排序应该很简单了,我们只需要将最后一个节点与第一个节点交换即可,换完一个size1,然后调整堆,调整的时候和删除是一样的,因为每条路径都是有序的,然后循环这个步骤,直到换完。

实现代码(利用上面的代码):

    public void heapSort() {
        int temp = size; // 先把原大小存一下
        for (int i = size; i > 0; --i) {
            --size;
            int t = heap[i];
            heap[i] = heap[1];
            heap[1] = t;
            
            int temp = heap[1];
            int current_index = 1 * 2;
            while (current_index <= size) {
                if (current_index < size && heap[current_index] < heap[current_index + 1]) {
                    current_index++;
                }
                if (heap[current_index] < temp) break;
                heap[current_index / 2] = heap[current_index];
                current_index *= 2;
            }
            heap[current_index / 2] = temp;

        }
        size = temp;
    }

TopN问题

这个就更加简单了,维护一个N就行了,然后删除n - 1次根节点,或者将根节点和一个结点交换,然后size - 1然后调整堆。

   public int topN(int n) {
        int res = 0;
        for (int i = size; i > 0; --i) {
            if (--n == 0) {
                res = heap[1];
                break;
            }
            removeMax();
        }
        return res;
    }

leetcode第347题:前 K 个高频元素

给定一个非空的整数数组,返回其中出现频率前 k 高的元素。

示例 1:

输入: nums = [1,1,1,2,2,3], k = 2
输出: [1,2]
示例 2:

输入: nums = [1], k = 1
输出: [1]
说明:

你可以假设给定的 k 总是合理的,且 1 ≤ k ≤ 数组中不相同的元素的个数。
你的算法的时间复杂度必须优于 O(n log n) , n 是数组的大小。

思路,先用一个map将他们相同的累加,然后遍历map插入堆,然后就是TopN问题了。

实现代码(带上上面的有泛型的堆实现代码,也可以用Java自带的优先队列PriorityQueue):

import java.lang.reflect.Array;
import java.util.*;

public class Number_347 {
    public List<Integer> topKFrequent(int[] nums, int k) {
        HashMap<Integer, Integer> map = new HashMap<>();
        for (int num : nums) {
            map.put(num, map.getOrDefault(num, 0) + 1);
        }
        MaxHeap_347<Node> heap = new MaxHeap_347<>(Node.class);
        for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
            heap.insert(new Node(entry.getKey(), entry.getValue()));
        }
        List<Integer> res = new ArrayList<>();
        for (int i = 0; i < k; ++i) {
            res.add(heap.removeMax().key);
        }
        return res;
    }

    static class Node implements Comparable<Node> {
        int key;
        int value;

        public Node(int key, int value) {
            this.key = key;
            this.value = value;
        }

        @Override
        public int compareTo(Node o) {
            return this.value - o.value;
        }
    }

    public static void main(String[] args) {
        Number_347 n = new Number_347();
        n.topKFrequent(new int[]{3, 3, 1, 1, 2}, 3).forEach(System.out::println);
    }
}

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值