堆是一棵完全二叉树,何为完全二叉树?参见。
有了这个定义,我们就可以将堆分成两种,一个叫小根堆,另一个叫大根堆,顾名思义,小根堆就是堆最小,大根堆就是根最大。
堆的性质:
- 堆具有所有完全二叉树的性质
- 堆的左右子树也是一个堆。
这里我们引入一个例题:【模板】堆,这一道题目就是一道裸题,很容易想到使用堆,那么堆有什么基本操作?
#include<bits/stdc++.h>
using namespace std;
int a[100010],n;
void put(int d){
cnt++;a[cnt]=d;
int j=cnt;
while(j>1){
int fa=j>>1;
if(a[fa]<a[j])break;
int tmp=a[fa];a[fa]=a[j];a[j]=tmp;
j=fa;
}
}
int del(){
int res=a[1];
a[1]=a[cnt--];
int j=1;
while(j*2<=cnt){
int u=j<<1;
if(a[u]>a[u+1] && j*2+1<=cnt)u++;
if(a[j]<=a[u])break;
int tmp=a[j];a[j]=a[u];a[u]=tmp;
j=u;
}
return res;
}
int main(){
int i,j,k,m;
scanf("%d",&n);
for(i=1;i<=n;i++){
scanf("%d",&a[i]);
put(a[i]);
}
return 0;
}
堆的操作分为两种,一个是插入,另一个是删除:
插入:
void put(int d){
cnt++;a[cnt]=d;
int j=cnt;
while(j>1){
int fa=j>>1;
if(a[fa]<a[j])break;
int tmp=a[fa];a[fa]=a[j];a[j]=tmp;
j=fa;
}
}
思想:每一次与他的父亲节点比较,如果小于(建造小根堆),那么就交换,否则就退出。
删除:
int del(){
int res=a[1];
a[1]=a[cnt--];
int j=1;
while(j*2<=cnt){
int u=j<<1;
if(a[u]>a[u+1] && j*2+1<=cnt)u++;
if(a[j]<=a[u])break;
int tmp=a[j];a[j]=a[u];a[u]=tmp;
j=u;
}
return res;
}
思想:删除根结点后,把最后一个结点放到根结点,然后不断下沉,如果大于孩子,就取孩子中小的那个,直到不能下沉。
P.S:注意,这里返回的是根节点的值,因为有的题目不一定是单独做一个简单的操作,应用的话可能会扯到这个。
以下是【模板】堆的代码
#include<bits/stdc++.h>
using namespace std;
int a[1000010],cnt;
void put(int d){
cnt++;a[cnt]=d;
int j=cnt;
while(j>1){
int fa=j>>1;
if(a[fa]<a[j])break;
int tmp=a[fa];a[fa]=a[j];a[j]=tmp;
j=fa;
}
}
void del(){
int res=a[1];
a[1]=a[cnt--];
int j=1;
while(j*2<=cnt){
int u=j<<1;
if(a[u]>a[u+1] && j*2+1<=cnt)u++;
if(a[j]<=a[u])break;
int tmp=a[j];a[j]=a[u];a[u]=tmp;
j=u;
}
}
int main(){
int i,j,k,n,m;
scanf("%d",&n);
for(i=1;i<=n;i++){
int num;
scanf("%d",&num);
if(num==1){
scanf("%d",&m);
put(m);
}
if(num==2)printf("%d\n",a[1]);
if(num==3)del();
}
return 0;
}
下面有一道合并果子的题目,这题我附上代码,希望大家能够真正地理解。
#include<bits/stdc++.h>
using namespace std;
int a[100010],n,cnt;
void put(int d){
cnt++;a[cnt]=d;
int j=cnt;
while(j>1){
int fa=j>>1;
if(a[fa]<a[j])break;
int tmp=a[fa];a[fa]=a[j];a[j]=tmp;
j=fa;
}
}
int del(){
int res=a[1];
a[1]=a[cnt--];
int j=1;
while(j*2<=cnt){
int u=j<<1;
if(a[u]>a[u+1] && j*2+1<=cnt)u++;
if(a[j]<=a[u])break;
int tmp=a[j];a[j]=a[u];a[u]=tmp;
j=u;
}
return res;
}
int main(){
int i,j,k,m;
scanf("%d",&n);
for(i=1;i<=n;i++){
scanf("%d",&a[i]);
put(a[i]);
}
int ans=0;
for(i=1;i<n;i++){
int sum=del();
k=del();
put(sum+k);
ans+=sum+k;
}
printf("%d\n",ans);
return 0;
}