Prim算法
Prim算法是一种常见并且好写的最小生成树算法。该算法的基本思想是从一个节点开始,不断加点 (这里假设节点为1~n),适用于点少边多的图。
算法步骤:
1.用二维矩阵记录每两点之间的费用
2.用一个dist数组维护各点与点集的最短距离
3.随便选一个起点,这里选1号点,然后每次选一个距离点集最近的点加入集合
4.更新dist数组
朴素prim算法
1.邻接矩阵存图
#include <stdio.h>
const int N = 1010, INF = 1e9;
int dist[N];
int g[N][N];
bool st[N];
int n, m, minSum; //n个节点,m条边
void prim(int s){ //以s为起点
minSum = 0;
memset(st, false, sizeof(st));
for(int i = 1; i <= n; i++) dist[i] = g[s][i];
dist[s] = 0;
st[s] = true;
while(true){
int v = -1, mind = INF;
for(int i = 1; i <= n; i++){
if(!st[i] && dist[i] < mind){
mind = dist[i];
v = i;
}
}
if(v == -1) break;
st[v] = true;
minSum += mind;
for(int i = 1; i <= n; i++){
if(!st[i] && g[v][i] < dist[i]) dist[i] = g[v][i];
}
}
}
int main(void){
int a, b, c;
scanf("%d %d", &n, &m);
for(int i = 0; i < m; i++){
scanf("%d %d %d", &a, &b, &c);
g[a][b] = g[b][a] = c;
}
prim(1);
printf("%d\n", minSum);
return 0;
}
2.邻接表存图
#include <stdio.h>
#include <string.h>
const int N = 1e5, M = 2e6, INF = 1e9;
int dist[N];
int h[N], to[M], w[M], ne[M], idx;
bool st[N];
int n, m, res;
void add(int a, int b, int c){
to[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void prim(int s){
for(int i = h[s]; ~i; i = ne[i]) dist[to[i]] = w[i];
dist[s] = 0;
st[s] = true;
while(true){
int v = -1, mind = INF;
for(int i = 1; i <= n; i++){
if(!st[i] && dist[i] < mind){
mind = dist[i];
v = i;
}
}
if(v == -1) break;
st[v] = true;
res += mind;
for(int i = h[v]; ~i; i = ne[i]){
if(!st[to[i]] && dist[to[i]] > w[i]) dist[to[i]] = w[i];
}
}
}
int main(void){
int a, b, c;
scanf("%d %d", &n, &m);
for(int i = 0; i < m; i++){
scanf("%d %d %d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
res = 0;
prim(1);
printf("%d", res);
return 0;
}
堆优化的prim算法
#include <stdio.h>
#include <string.h>
const int N = 1e5, M = 2e6, INF = 1e9;
int dist[N];
int h[N], to[M], w[M], ne[M], idx;
int heap[M][2];
int hsize;
bool st[N];
int minv, mind;
int n, res;
void push(int v, int d){
int i;
for(i = ++hsize; heap[i / 2][0] > d; i /= 2){
heap[i][0] = heap[i / 2][0];
heap[i][1] = heap[i / 2][1];
}
heap[i][0] = d;
heap[i][1] = v;
}
void pop(){
mind = heap[1][0];
minv = heap[1][1];
int lastd = heap[hsize][0];
int lastv = heap[hsize--][1];
int i, child;
for(i = 1; i * 2 <= hsize; i = child){
child = i * 2;
if(child != hsize && heap[child + 1][0] < heap[child][0]) child++;
if(lastd > heap[child][0]){
heap[i][0] = heap[child][0];
heap[i][1] = heap[child][1];
}
else break;
}
heap[i][0] = lastd;
heap[i][1] = lastv;
}
void add(int a, int b, int c){
to[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void prim(int s){
for(int i = h[s]; ~i; i = ne[i]){
dist[to[i]] = w[i];
push(to[i], dist[to[i]]);
}
st[s] = true;
while(hsize){
pop();
if(st[minv]) continue;
st[minv] = true;
res += mind;
for(int i = h[minv]; ~i; i = ne[i]){
if(!st[i] && w[i] < dist[to[i]]){
dist[to[i]] = w[i];
push(to[i], dist[to[i]]);
}
}
}
}
int main(void){
int a, b, c;
scanf("%d %d", &n, &m);
for(int i = 0; i < m; i++){
scanf("%d %d %d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
res = 0;
prim(1);
printf("%d", res);
return 0;
}