基于二叉堆的Prim算法
一. Prim 算法简介
向一个空集合中加入一个点,然后加入这个集合内所有点的相邻边中的最短边的相邻点,重复直至每一个点都进入集合
二. 二叉堆优化的意义
1.先来谈谈二叉堆
二叉堆是一种数据结构,保证它的子节点的数值全部都小于他们的父节点,所以可以保证堆顶的元素的值是最小的,这个和我们的需求不谋而合
2先来看看二叉堆部分的源码
// 这个是根据这一题设计的数据结构,emmm板子的话和这个有一点差别
struct heap
{
int num = 0;
p data[N*N];
void init()
{
for (int i = 0; i < N*N; ++i)
{
data[i].pay = inf;
}
}
// 初始化
inline void pop()
{
swap(data[1],data[num]);
data[num--].pay = inf;
fromTop(1);
}
// 弹出头节点,并且从上向下进行重新调整
inline p top()
{
return data[1];
}
// 返回头节点
inline void push(p x)
{
data[++num] = x;
fromBottom(num);
}
// 加入一个节点,并且从尾部向头节点进行调整
inline void fromBottom(int x)
{
while (x!=1)
{
if(data[fa(x)].pay>data[x].pay)
{
swap(data[fa(x)],data[x]);
x = fa(x);
}
else
return;
}
return;
}
// 从尾部向头进行重新排列
inline void fromTop(int x)
{
while (data[ls(x)].pay!=inf)
{
int target = (data[ls(x)].pay<data[rs(x)].pay)? ls(x): rs(x);
if(data[x].pay > data[target].pay)
{
swap(data[x], data[target]);
x = target;
}
else
return;
}
return;
}
// 从头部
}Heap;
三 基于二叉堆的优化
我们用堆去存储每一次加入的边,然后取出头部的那一条边emmm如果这一条边连的点已经被访问,那么继续取!
#include<iostream>
#include<algorithm>
#include<vector>
#define ls(x) (2*x)
#define rs(x) (2*x+1)
#define fa(x) (x/2)
#define N 600
#define inf 1e9
using namespace std;
int mark[N];
struct p
{
int to;
int pay;
}P;
vector<p> v[N];
vector<int>q;
struct heap
{
int num = 0;
p data[N*N];
void init()
{
for (int i = 0; i < N*N; ++i)
{
data[i].pay = inf;
}
}
inline void pop()
{
swap(data[1],data[num]);
data[num--].pay = inf;
fromTop(1);
}
inline p top()
{
return data[1];
}
inline void push(p x)
{
data[++num] = x;
fromBottom(num);
}
inline void fromBottom(int x)
{
while (x!=1)
{
if(data[fa(x)].pay>data[x].pay)
{
swap(data[fa(x)],data[x]);
x = fa(x);
}
else
return;
}
return;
}
inline void fromTop(int x)
{
while (data[ls(x)].pay!=inf)
{
int target = (data[ls(x)].pay<data[rs(x)].pay)? ls(x): rs(x);
if(data[x].pay > data[target].pay)
{
swap(data[x], data[target]);
x = target;
}
else
return;
}
return;
}
}Heap;
int cnt;
int ans;
int main()
{
int a,b;
Heap.init();
scanf("%d%d",&a,&b);
for (int i = 1; i <= b; ++i)
{
for (int j = 1; j <=b; ++j)
{
int x;
scanf("%d",&x);
if(x)
{
if(x>a) x =a;
P.to = j;
P.pay = x;
v[i].push_back(P);
}
}
}
int num = 0;
for (int i = 1; i <= b&&num<b; i++)
{
if(!mark[i])
{
q.push_back(i);
ans+=a;
// 如果这一点没有相连的点,那么只能加a
num++;
while(!q.empty())
{
int k = q.back();
mark[k] = 1;
for (int j = 0; j < v[k].size(); j++)
{
if(!mark[v[k][j].to])
{
Heap.push(v[k][j]);
}
// 如果to的点还没有被访问,那么我们加入这一条边到heap内
}
q.pop_back();
while (mark[Heap.top().to])
{
Heap.pop();
}
// 因为之前存的点可能有失效的部分,那么弹出
if(Heap.num)
{
ans+=Heap.top().pay;
num++;
if(num==b) goto doit;
q.push_back(Heap.top().to);
}
// 向集合中加入这个点
}
}
}
doit:
printf("%d\n", ans);
}