适合入门选手使用的prim最小生成树算法
这段时间在学图论算法,刷了一些模板题,也看了一些大牛的代码,发现有一个问题是许多代码虽然看起来简洁却不好理解,还用了大量的STL内容。因此,我将一个较好理解的代码放上来,希望能帮到大家。
该代码的算法跟标准的算法区别不大,只是较好理解,建议先看完其他地方的算法描述再看下面的代码实现
时间复杂度
O(n2)
O
(
n
2
)
#include<iostream>
#include<cstring>
#define maxn 5005
#define INF 99999999
int n,m;
long long ans;
int g[maxn][maxn];//邻接矩阵存图
int key[maxn];//key[v]表示蓝点v与白点相连的最小边权
int used[maxn];//表示是蓝点还是白点
using namespace std;
long long prim(){
memset(key,0x7f,sizeof(key));//初始化为极大值
key[1]=0;
memset(used,0,sizeof(used));//初始化为0,表示未添加进最小生成树(即算法描述中的“蓝点”
for(int i=1;i<=n;i++){
int k=0;
for(int j=1;j<=n;j++){//找一个与白点相连的权值最小的蓝点k
//这里可以用堆来优化
if(used[j]==0&&key[j]<key[k])
k=j;
}
used[k]=1;//蓝点k加入最小生成树,标记为白点
for(int j=1;j<=n;j++){//修改与k相邻的所有白点
if(used[j]==0&&g[k][j]<key[j])
key[j]=g[k][j];
}
}
long long sum=0;
for(int u=1;u<=n;u++) sum+=key[u];//累加权值
return sum;
}
int main(){
int a,b,c;
cin>>n>>m;
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
if(i==j)g[i][j]=0;
else g[i][j]=INF;
}
}
for(int i=1;i<=m;i++){
cin>>a>>b>>c;
if(c<g[a][b])
{
g[a][b]=g[b][a]=c;
/*注意判重,如果两点之间有多条路径那么取权值最小的一条, 有些图论题就在这里设置陷阱*/
}
}
ans=prim();
if(ans>=10000000)cout<<"orz"<<endl;//如果不联通就输出orz
else cout<<ans;
return 0;
}
另外,该算法还可以通过堆优化将时间复杂度降为 O(nlog2n) O ( n log 2 n )
#include<iostream>
#include<cstring>
#include<queue>
#define maxn 5005
#define maxm 2*200000+5
using namespace std;
int n,m;
int head[maxn];
int key[maxn],used[maxn];
struct edge_table{
int to;
int next;
int value;
}edge[maxm];
int edge_cnt=0;
struct heap_node{
int id;
int value;
friend bool operator <(heap_node a,heap_node b){
return a.value>b.value;
}
};
void add_edge(int u,int v,int w){
edge[++edge_cnt].to=v;
edge[edge_cnt].value=w;
edge[edge_cnt].next=head[u];
head[u]=edge_cnt;
}
int prim(){
int ans=0,tot=0;
memset(key,0x7f,sizeof(key));
memset(used,0,sizeof(used));
priority_queue<heap_node>heap;
heap_node now,nex;
now.id=1;
now.value=key[1]=0;
heap.push(now);
while(!heap.empty()){
now=heap.top();
heap.pop();
int u=now.id;
if(now.value!=key[u]) continue;
used[u]=1;
ans+=key[u];
tot++;
for(int i=head[u];i;i=edge[i].next){
int v=edge[i].to;
if(used[v]==0&&key[v]>edge[i].value){
key[v]=edge[i].value;
nex.value=key[v];
nex.id=v;
heap.push(nex);
}
}
}
if(tot<n) ans=-1;
return ans;
}
int main(){
cin>>n>>m;
for(int i=1;i<=m;i++){
int a,b,c;
cin>>a>>b>>c;
add_edge(a,b,c);
add_edge(b,a,c);
}
int ans=prim();
if(ans==-1) cout<<"orz"<<endl;
else cout<<ans;
}