朴素Prim
朴素Prim算法和朴素Dijkstra算法的流程非常相似:
// 将所有点到集合的距离初始化为正无穷
dist[i] = 0x3f3f3f3f;
// n次迭代
for(int i = 0; i < n; i++){
1、找到不在集合当中的距离最小的点t(集合表示当前的连通块,即生成树)
2、用t来更新其他点到集合的距离(Dijkstra是用t来更新到起点的距离)
3、将t加入到集合中
}
最小生成树问题不存在环,所以正边和负边都是可以的。当不存在生成树时,说明所有点不连通。
AcWing 858. Prim算法求最小生成树
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 510, INF = 0x3f3f3f3f;
int g[N][N];
int dist[N];
bool st[N];
int n, m;
int prim(){
memset(dist, 0x3f, sizeof dist);
int res = 0;
for(int i = 0; i < n; i++)
{
int t = -1;
for(int j = 1; j <= n; j++)
if(!st[j] && (t == -1 || dist[t] > dist[j]))
t = j;
if(i && dist[t] == INF) return INF;
if(i) res += dist[t];
for(int j = 1; j <= n; j++)
dist[j] = min(dist[j], g[t][j]);
st[t] = true;
}
return res;
}
int main(){
cin >> n >> m;
memset(g, 0x3f, sizeof g);
while(m--)
{
int a, b, c;
cin >> a >> b >> c;
g[a][b] = g[b][a] = min(g[a][b], c);
}
int t = prim();
if(t == INF) puts("impossible");
else printf("%d\n", t);
return 0;
}
Kruskal
1、将所有边按权重从小到大排序 sort() --- O(mlogm);
2、枚举每条边a-b,权值为c
如果a b不连通,那么就将a-b加到集合中(并查集的应用) --- O(m * 1)
AcWing 859. Kruskal算法求最小生成树
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200010, INF = 0x3f3f3f3f;
int n, m;
int p[N];
struct Edge{
int a, b, w;
// 重载 <
bool operator < (const Edge &W)const
{
return w < W.w;
}
}edges[N];
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
int kruskal(){
sort(edges, edges + m);
for(int i = 1; i <= n; i++) p[i] = i;
int res = 0, cnt = 0;
for(int i = 0; i < m; i++)
{
int a = edges[i].a, b = edges[i].b, w = edges[i].w;
a = find(a), b = find(b);
if(a != b)
{
p[a] = b;
res += w;
cnt ++;
}
}
if(cnt < n - 1) return INF;
return res;
}
int main(){
cin >> n >> m;
for(int i = 0; i < m; i++)
{
int a, b, c;
cin >> a >> b >> c;
edges[i] = {a, b, c};
}
int t = kruskal();
if(t == INF) puts("impossible");
else printf("%d\n", t);
return 0;
}
我们应该要根据题给的范围(n和m的大小),来决定是使用朴素Prim算法(O(n^2))还是Kruskal算法(O(mlogm))。
AcWing 3728. 城市通电
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#define x first
#define y second
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 2010;
PII p[N];
int wc[N], wk[N], fa[N];
LL dist[N];
bool st[N];
vector<int> ans1;
vector<PII> ans2;
int n;
LL get_dist(int a, int b){
int dx = abs(p[a].x - p[b].x);
int dy = abs(p[a].y - p[b].y);
return (LL)(dx + dy) * (wk[a] + wk[b]);
}
LL prim(){
memset(dist, 0x3f, sizeof dist);
dist[0] = 0;
st[0] = true;
for(int i = 1; i <= n; i++) dist[i] = wc[i];
LL res = 0;
for(int i = 0; i < n; i++)
{
int t = -1;
for(int j = 1; j <= n; j++)
if(!st[j] && (t == -1 || dist[j] < dist[t]))
t = j;
res += dist[t];
st[t] = true;
if(!fa[t]) ans1.push_back(t);
else ans2.push_back({fa[t], t});
for(int j = 1; j <= n; j++)
{
if(dist[j] > get_dist(t, j))
{
dist[j] = get_dist(t, j);
fa[j] = t;
}
}
}
return res;
}
int main(){
cin >> n;
for(int i = 1; i <= n; i++)
scanf("%d %d", &p[i].x, &p[i].y);
for(int i = 1; i <= n; i++)
scanf("%d", &wc[i]);
for(int i = 1; i <= n; i++)
scanf("%d", &wk[i]);
LL t = prim();
printf("%lld\n", t);
printf("%d\n", (int)ans1.size());
for(int x : ans1)
printf("%d ", x);
printf("\n%d\n", (int)ans2.size());
for(auto& [x, y] : ans2)
printf("%d %d\n", x, y);
return 0;
}