优先队列和堆的总结
- 堆的基础知识
- 堆实现(使用动态数组)
- 使用堆实现优先队列
LeetCode-347. Top K Frequent Elements
测试优先队列(Top K
问题)
堆的基础知识
堆的基本性质:
- 堆首先是一颗完全二叉树;
- 堆要满足的性质是或者每一个结点都大于孩子,或者都小于孩子结点的值;
- 最大堆: 所有结点都比自己的孩子结点大;
- 最小堆: 所有结点都比自己的孩子结点小;
堆用数组来表示的话,数组可以从1
位置开始,也可以从0
位置开始,一般使用从0
位置开始,和孩子的下标对应关系如下:
从1
开始的对应关系:
从0
开始的对应关系:
堆实现(使用动态数组)
这里使用动态数组作为堆中的存储结构,也可以不使用动态数组,但是不能扩容的数组效率相对会比较低,动态数组实现看这篇博客。
public class Array<E> {
private E[] data;
private int size;
public Array(int capacity){ // user assign size
data = (E[])new Object[capacity];
size = 0;
}
public Array(){
this(10); // default size
}
public Array(E[] arr){
data = (E[])new Object[arr.length];
for(int i = 0 ; i < arr.length ; i ++)
data[i] = arr[i];
size = arr.length;
}
public int getSize(){
return size;
}
public int getCapacity(){
return data.length;
}
public boolean isEmpty(){
return size == 0;
}
public void rangeCheck(int index){
if(index < 0 || index >= size)
throw new IllegalArgumentException("Index is Illegal!");
}
public void add(int index,E e){
if(index < 0 || index > size){
throw new IllegalArgumentException("Index is Illegal ! ");
}
if(size == data.length)
resize(data.length * 2);
for(int i = size - 1; i >= index; i--){
data[i+1] = data[i];
}
data[index] = e;
size++;
}
private void resize(int newCapacity) {
E[] newData = (E[])new Object[newCapacity];
for(int i = 0; i < size; i++) newData[i] = data[i];
data = newData;
}
public void addLast(E e){ //末尾添加
add(size,e);
}
public void addFirst(E e){ //头部添加
add(0,e);
}
public E get(int index){
rangeCheck(index);
return data[index];
}
public E getLast(){
return get(size-1);
}
public E getFirst(){
return get(0);
}
public void set(int index,E e){
rangeCheck(index);
data[index] = e;
}
public boolean contains(E e){
for(int i = 0; i < size; i++){
if(data[i].equals(e))return true;
}
return false;
}
public int find(E e){
for(int i = 0; i < size; i++){
if(data[i].equals(e))return i;
}
return -1;
}
public E remove(int index){ // remove data[index] and return the value
rangeCheck(index);
E res = data[index];
for(int i = index; i < size-1; i++){
data[i] = data[i+1];
}
size--;
data[size] = null;//loitering objects != memory leak
if(size == data.length / 4 && data.length/2 != 0)resize(data.length / 2); //防止复杂度的震荡
return res;
}
public E removeFirst(){
return remove(0);
}
public E removeLast(){
return remove(size-1);
}
public void removeElement(E e){ //only remove one(may repetition) and user not know whether is deleted.
int index = find(e);
if(index != -1){
remove(index);
}
}
// new method
public void swap(int i, int j){
if(i < 0 || i >= size || j < 0 || j >= size)
throw new IllegalArgumentException("Index is illegal.");
E t = data[i];
data[i] = data[j];
data[j] = t;
}
@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append(String.format("Array : size = %d, capacity = %d\n",size,data.length));
res.append("[");
for(int i = 0; i < size; i++){
res.append(data[i]);
if(i != size-1){
res.append(", ");
}
}
res.append("]");
return res.toString();
}
}
MaxHeap
的实现:
add()
操作的逻辑是先把元素加入到动态数组的最后,然后向上浮,注意其中的siftUp
操作就是将一个元素向上浮使得满足最大堆的操作;extractMax()
就是取出堆顶的元素并且返回,其中步骤是先交换堆顶和最后一个元素,然后将新的堆顶往下沉的操作;replace()
操作就是取出堆中的最大元素,并且替换成元素e
;
/**
* 使用动态数组实现堆
* @param <E>
*/
public class MaxHeap<E extends Comparable<E>> {
private Array<E> data; // 里面存放的是动态的数组
public MaxHeap(int capacity){
data = new Array<>(capacity);
}
public MaxHeap(){
data = new Array<>();
}
public MaxHeap(E[] arr){
data = new Array<>(arr);
for(int i = (arr.length-1-1)/2; i >= 0; i--)
siftDown(i);
}
public int size(){
return data.getSize();
}
public boolean isEmpty(){
return data.isEmpty();
}
public void add(E e){// 向堆中添加元素
data.addLast(e); //在数组的末尾添加
siftUp(data.getSize() - 1);
}
/**上浮操作 */
private void siftUp(int index) {
while(data.get(index).compareTo(data.get((index-1)/2)) > 0){ //比父亲大
data.swap(index,(index-1)/2);
index = (index-1)/2; //继续向上
}
}
/** 看堆中的最大元素*/
public E findMax(){
if(data.getSize() == 0)
throw new IllegalArgumentException("Can not findMax when heap is empty.");
return data.get(0);
}
public E extractMax(){ /** 取出堆中最大元素(堆顶) */
E ret = findMax();
data.swap(0, data.getSize() - 1); //堆顶和最后一个元素交换
data.removeLast(); //把最后一个元素删除 ,也就是把之前的堆顶移出堆
siftDown(0); //新的堆顶要下沉
return ret;
}
/**下沉操作*/
private void siftDown(int i) {
int L = 2 * i + 1; //左孩子
while(L < data.getSize()){ //不越界
int maxIdx = L + 1 < data.getSize() && data.get(L+1).compareTo(data.get(L)) > 0 ? L+1 : L;
maxIdx = data.get(i).compareTo(data.get(maxIdx)) > 0 ? i : maxIdx;
if(maxIdx == i) break; //自己就是最大的,不需要下沉
data.swap(i,maxIdx);
i = maxIdx;
L = 2 * i + 1;
}
}
public E replace(E e){ // 取出堆中的最大元素,并且替换成元素e 一个O(logn)
E ret = findMax();
data.set(0, e);
siftDown(0);
return ret;
}
}
- 其中一个很优化的地方就是
heapify
的过程,也就是在构造函数中的从(arr.length - 1 - 1) / 2
的过程开始建堆,这比我们一个个的添加到堆中效率要高,原理就是找到第一个非叶子结点开始调整:
使用堆实现优先队列
/**
* 优先队列的接口
*/
public interface Queue<E> {
int getSize();
boolean isEmpty();
void enqueue(E e);
E dequeue();
E getFront();
}
public class PriorityQueue<E extends Comparable<E>> implements Queue<E> {
private MaxHeap<E>maxHeap;
public PriorityQueue(){
maxHeap = new MaxHeap<>();
}
@Override
public int getSize() {
return maxHeap.size();
}
@Override
public boolean isEmpty() {
return maxHeap.isEmpty();
}
@Override
public void enqueue(E e) {
maxHeap.add(e);
}
@Override
public E dequeue() {
return maxHeap.extractMax(); //取出最大元素
}
@Override
public E getFront() {
return maxHeap.findMax();
}
}
LeetCode-347. Top K Frequent Elements测试优先队列(Top K问题)
题目链接
题目
解析
使用上面的优先队列我们就可以维护一个有K个数的优先队列,然后遍历Map中的每一个数,如果堆中还没有K个数,直接入堆,如果有了K个数,就比较堆顶(堆中频次较低的),如果当前元素比堆顶的频次高,就从堆顶中弹出堆顶,并且让这个元素进入堆顶。
private class Freq implements Comparable<Freq>{
public int e;
public int freq; //元素和频次
public Freq(int e, int freq) {
this.e = e;
this.freq = freq;
}
@Override
public int compareTo(Freq o) {//按照频次升序 频次低的在堆顶
return this.freq < o.freq ? 1 : (this.freq > o.freq ? -1 : 0);
}
}
//维护一个K个数的优先队列
public List<Integer> topKFrequent(int[] nums, int k) {
HashMap<Integer,Integer> map = new HashMap<>();
for(int num : nums){
if(map.containsKey(num)){
map.put(num,map.get(num) + 1);
}else {
map.put(num,1);
}
}
PriorityQueue<Freq>queue = new PriorityQueue<>();
//维护一个 前 k个频次的堆
for(int key : map.keySet()){
if(queue.getSize() < k){
queue.enqueue(new Freq(key,map.get(key)));
}else{
if(map.get(key) > queue.getFront().freq){
queue.dequeue();
queue.enqueue(new Freq(key,map.get(key)));
}
}
}
List<Integer>res = new ArrayList<>();
while(!queue.isEmpty()){
res.add(queue.dequeue().e);
}
return res;
}
使用Java中的优先队列,并且基于Comparator
接口中的compare()
方法比较,实现方法差不多:
private class Freq {
public int e;
public int freq; //元素和频次
public Freq(int e, int freq) {
this.e = e;
this.freq = freq;
}
}
private class FreqComparator implements Comparator<Freq>{
@Override
public int compare(Freq o1, Freq o2) {
return o1.freq - o2.freq;
}
}
//维护一个K个数的优先队列
public List<Integer> topKFrequent(int[] nums, int k) {
HashMap<Integer,Integer> map = new HashMap<>();
for(int num : nums){
if(map.containsKey(num)){
map.put(num,map.get(num) + 1);
}else {
map.put(num,1);
}
}
PriorityQueue<Freq>queue = new PriorityQueue<>(new FreqComparator());
//维护一个 前 k个频次的堆
for(int key : map.keySet()){
if(queue.size() < k){ //堆中个数<k,直接入堆
queue.add(new Freq(key,map.get(key)));
}else{
if(map.get(key) > queue.peek().freq){
queue.poll();
queue.add(new Freq(key,map.get(key)));
}
}
}
List<Integer>res = new ArrayList<>();
while(!queue.isEmpty()){
res.add(queue.poll().e);
}
return res;
}
再次优化,使用匿名内部类以及省去那个Freq
类(优先队列中只需要存储对应的值,因为map
中已经存储了频次),并且使用Lambda表达式:
//维护一个K个数的优先队列
public List<Integer> topKFrequent(int[] nums, int k) {
HashMap<Integer,Integer> map = new HashMap<>();
for(int num : nums){
if(map.containsKey(num)){
map.put(num,map.get(num) + 1);
}else {
map.put(num,1);
}
}
PriorityQueue<Integer> queue = new PriorityQueue<>((o1, o2) -> map.get(o1) - map.get(o2));
//维护一个 前 k个频次的堆
for(int key : map.keySet()){
if(queue.size() < k){ //堆中个数<k,直接入堆
queue.add(key);
}else{
if(map.get(key) > map.get(queue.peek())){
queue.poll();
queue.add(key);
}
}
}
List<Integer>res = new ArrayList<>();
while(!queue.isEmpty()){
res.add(queue.poll());
}
return res;
}
最后再贴上我们在LeetCode中测试优先队列和堆的程序(较长):
具体代码如下:
class Solution {
private class Array<E> {
private E[] data;
private int size;
public Array(int capacity){ // user assign size
data = (E[])new Object[capacity];
size = 0;
}
public Array(){
this(10); // default size
}
public Array(E[] arr){
data = (E[])new Object[arr.length];
for(int i = 0 ; i < arr.length ; i ++)
data[i] = arr[i];
size = arr.length;
}
public int getSize(){
return size;
}
public int getCapacity(){
return data.length;
}
public boolean isEmpty(){
return size == 0;
}
public void rangeCheck(int index){
if(index < 0 || index >= size)
throw new IllegalArgumentException("Index is Illegal!");
}
public void add(int index,E e){
if(index < 0 || index > size){
throw new IllegalArgumentException("Index is Illegal ! ");
}
if(size == data.length)
resize(data.length * 2);
for(int i = size - 1; i >= index; i--){
data[i+1] = data[i];
}
data[index] = e;
size++;
}
private void resize(int newCapacity) {
E[] newData = (E[])new Object[newCapacity];
for(int i = 0; i < size; i++) newData[i] = data[i];
data = newData;
}
public void addLast(E e){ //末尾添加
add(size,e);
}
public void addFirst(E e){ //头部添加
add(0,e);
}
public E get(int index){
rangeCheck(index);
return data[index];
}
public E getLast(){
return get(size-1);
}
public E getFirst(){
return get(0);
}
public void set(int index,E e){
rangeCheck(index);
data[index] = e;
}
public boolean contains(E e){
for(int i = 0; i < size; i++){
if(data[i].equals(e))return true;
}
return false;
}
public int find(E e){
for(int i = 0; i < size; i++){
if(data[i].equals(e))return i;
}
return -1;
}
public E remove(int index){ // remove data[index] and return the value
rangeCheck(index);
E res = data[index];
for(int i = index; i < size-1; i++){
data[i] = data[i+1];
}
size--;
data[size] = null;//loitering objects != memory leak
if(size == data.length / 4 && data.length/2 != 0)resize(data.length / 2); //防止复杂度的震荡
return res;
}
public E removeFirst(){
return remove(0);
}
public E removeLast(){
return remove(size-1);
}
public void removeElement(E e){ //only remove one(may repetition) and user not know whether is deleted.
int index = find(e);
if(index != -1){
remove(index);
}
}
// new method
public void swap(int i, int j){
if(i < 0 || i >= size || j < 0 || j >= size)
throw new IllegalArgumentException("Index is illegal.");
E t = data[i];
data[i] = data[j];
data[j] = t;
}
@Override
public String toString() {
StringBuilder res = new StringBuilder();
res.append(String.format("Array : size = %d, capacity = %d\n",size,data.length));
res.append("[");
for(int i = 0; i < size; i++){
res.append(data[i]);
if(i != size-1){
res.append(", ");
}
}
res.append("]");
return res.toString();
}
}
private class MaxHeap<E extends Comparable<E>> {
private Array<E> data; // 里面存放的是动态的数组
public MaxHeap(int capacity){
data = new Array<>(capacity);
}
public MaxHeap(){
data = new Array<>();
}
public MaxHeap(E[] arr){
data = new Array<>(arr);
for(int i = (arr.length-1-1)/2; i >= 0; i--)
siftDown(i);
}
public int size(){
return data.getSize();
}
public boolean isEmpty(){
return data.isEmpty();
}
public void add(E e){// 向堆中添加元素
data.addLast(e); //在数组的末尾添加
siftUp(data.getSize() - 1);
}
/**上浮操作 */
private void siftUp(int index) {
while(data.get(index).compareTo(data.get((index-1)/2)) > 0){ //比父亲大
data.swap(index,(index-1)/2);
index = (index-1)/2; //继续向上
}
}
/** 看堆中的最大元素*/
public E findMax(){
if(data.getSize() == 0)
throw new IllegalArgumentException("Can not findMax when heap is empty.");
return data.get(0);
}
public E extractMax(){ /** 取出堆中最大元素(堆顶) */
E ret = findMax();
data.swap(0, data.getSize() - 1); //堆顶和最后一个元素交换
data.removeLast(); //把最后一个元素删除 ,也就是把之前的堆顶移出堆
siftDown(0); //新的堆顶要下沉
return ret;
}
/**下沉操作*/
private void siftDown(int i) {
int L = 2 * i + 1; //左孩子
while(L < data.getSize()){ //不越界
int maxIdx = L + 1 < data.getSize() && data.get(L+1).compareTo(data.get(L)) > 0 ? L+1 : L;
maxIdx = data.get(i).compareTo(data.get(maxIdx)) > 0 ? i : maxIdx;
if(maxIdx == i) break; //自己就是最大的,不需要下沉
data.swap(i,maxIdx);
i = maxIdx;
L = 2 * i + 1;
}
}
public E replace(E e){ // 取出堆中的最大元素,并且替换成元素e 一个O(logn)
E ret = findMax();
data.set(0, e);
siftDown(0);
return ret;
}
}
private interface Queue<E> {
int getSize();
boolean isEmpty();
void enqueue(E e);
E dequeue();
E getFront();
}
private class MyPriorityQueue<E extends Comparable<E>> implements Queue<E> {
private MaxHeap<E>maxHeap;
public MyPriorityQueue(){
maxHeap = new MaxHeap<>();
}
@Override
public int getSize() {
return maxHeap.size();
}
@Override
public boolean isEmpty() {
return maxHeap.isEmpty();
}
@Override
public void enqueue(E e) {
maxHeap.add(e);
}
@Override
public E dequeue() {
return maxHeap.extractMax(); //取出最大元素
}
@Override
public E getFront() {
return maxHeap.findMax();
}
}
private class Freq implements Comparable<Freq>{
public int e;
public int freq; //元素和频次
public Freq(int e, int freq) {
this.e = e;
this.freq = freq;
}
@Override
public int compareTo(Freq o) {//按照频次升序 频次低的在堆顶
return this.freq < o.freq ? 1 : (this.freq > o.freq ? -1 : 0);
}
}
//维护一个K个数的优先队列
public List<Integer> topKFrequent(int[] nums, int k) {
HashMap<Integer,Integer> map = new HashMap<>();
for(int num : nums){
if(map.containsKey(num)){
map.put(num,map.get(num) + 1);
}else {
map.put(num,1);
}
}
MyPriorityQueue<Freq>queue = new MyPriorityQueue<>();
//维护一个 前 k个频次的堆
for(int key : map.keySet()){
if(queue.getSize() < k){
queue.enqueue(new Freq(key,map.get(key)));
}else{
if(map.get(key) > queue.getFront().freq){
queue.dequeue();
queue.enqueue(new Freq(key,map.get(key)));
}
}
}
List<Integer>res = new ArrayList<>();
while(!queue.isEmpty()){
res.add(queue.dequeue().e);
}
return res;
}
}
也可以将自己写的动态数组Array
换成JDK
的ArrayList
:
class Solution {
private class MaxHeap<E extends Comparable<E>> {
private ArrayList<E> data; // 里面存放的是动态的数组
public MaxHeap(int capacity){
data = new ArrayList<>(capacity);
}
public MaxHeap(){
data = new ArrayList<>();
}
public MaxHeap(E[] arr){
data = new ArrayList<>(arr.length);
for(int i = (arr.length-1-1)/2; i >= 0; i--)
siftDown(i);
}
public int size(){
return data.size();
}
public boolean isEmpty(){
return data.isEmpty();
}
public void add(E e){// 向堆中添加元素
data.add(data.size(),e);
siftUp(data.size() - 1);
}
/**上浮操作 */
private void siftUp(int index) {
while(data.get(index).compareTo(data.get((index-1)/2)) > 0){ //比父亲大
swap(index, (index-1)/2);
index = (index-1)/2; //继续向上
}
}
/** 看堆中的最大元素*/
public E findMax(){
if(data.size() == 0)
throw new IllegalArgumentException("Can not findMax when heap is empty.");
return data.get(0);
}
public E extractMax(){ /** 取出堆中最大元素(堆顶) */
E ret = findMax();
swap(0, data.size()-1);
data.remove(data.size()-1); //把最后一个元素删除 ,也就是把之前的堆顶移出堆
siftDown(0); //新的堆顶要下沉
return ret;
}
/**下沉操作*/
private void siftDown(int i) {
int L = 2 * i + 1; //左孩子
while(L < data.size()){ //不越界
int maxIdx = L + 1 < data.size() && data.get(L+1).compareTo(data.get(L)) > 0 ? L+1 : L;
maxIdx = data.get(i).compareTo(data.get(maxIdx)) > 0 ? i : maxIdx;
if(maxIdx == i) break; //自己就是最大的,不需要下沉
swap(i, maxIdx);
i = maxIdx;
L = 2 * i + 1;
}
}
public E replace(E e){ // 取出堆中的最大元素,并且替换成元素e 一个O(logn)
E ret = findMax();
data.set(0, e);
siftDown(0);
return ret;
}
public void swap(int i, int j){
E temp = data.get(i);
data.set(i, data.get(j));
data.set(j, temp);
}
}
private interface Queue<E> {
int getSize();
boolean isEmpty();
void enqueue(E e);
E dequeue();
E getFront();
}
private class MyPriorityQueue<E extends Comparable<E>> implements Queue<E> {
private MaxHeap<E>maxHeap;
public MyPriorityQueue(){
maxHeap = new MaxHeap<>();
}
@Override
public int getSize() {
return maxHeap.size();
}
@Override
public boolean isEmpty() {
return maxHeap.isEmpty();
}
@Override
public void enqueue(E e) {
maxHeap.add(e);
}
@Override
public E dequeue() {
return maxHeap.extractMax(); //取出最大元素
}
@Override
public E getFront() {
return maxHeap.findMax();
}
}
private class Freq implements Comparable<Freq>{
public int e;
public int freq; //元素和频次
public Freq(int e, int freq) {
this.e = e;
this.freq = freq;
}
@Override
public int compareTo(Freq o) {//按照频次升序 频次低的在堆顶
return this.freq < o.freq ? 1 : (this.freq > o.freq ? -1 : 0);
}
}
//维护一个K个数的优先队列
public List<Integer> topKFrequent(int[] nums, int k) {
HashMap<Integer,Integer> map = new HashMap<>();
for(int num : nums){
if(map.containsKey(num)){
map.put(num,map.get(num) + 1);
}else {
map.put(num,1);
}
}
MyPriorityQueue<Freq>queue = new MyPriorityQueue<>();
//维护一个 前 k个频次的堆
for(int key : map.keySet()){
if(queue.getSize() < k){
queue.enqueue(new Freq(key,map.get(key)));
}else{
if(map.get(key) > queue.getFront().freq){
queue.dequeue();
queue.enqueue(new Freq(key,map.get(key)));
}
}
}
List<Integer>res = new ArrayList<>();
while(!queue.isEmpty()){
res.add(queue.dequeue().e);
}
return res;
}
}