Prim算法
1.概览
普里姆算法(Prim算法),图论中的一种算法,可在加权连通图里搜索最小生成树。意即由此算法搜索到的边子集所构成的树中,不但包括了连通图里的所有顶点(英语:Vertex (graph theory)),且其所有边的权值之和亦为最小。该算法于1930年由捷克数学家沃伊捷赫·亚尔尼克(英语:Vojtěch Jarník)发现;并在1957年由美国计算机科学家罗伯特·普里姆(英语:Robert C. Prim)独立发现;1959年,艾兹格·迪科斯彻再次发现了该算法。因此,在某些场合,普里姆算法又被称为DJP算法、亚尔尼克算法或普里姆-亚尔尼克算法。
2.算法简单描述
1).输入:一个加权连通图,其中顶点集合为V,边集合为E;
2).初始化:Vnew = {x},其中x为集合V中的任一节点(起始点),Enew = {},为空;
3).重复下列操作,直到Vnew = V:
a.在集合E中选取权值最小的边<u, v>,其中u为集合Vnew中的元素,而v不在Vnew集合当中,并且v∈V(如果存在有多条满足前述条件即具有相同权值的边,则可任意选取其中之一);
b.将v加入集合Vnew中,将<u, v>边加入集合Enew中;
4).输出:使用集合Vnew和Enew来描述所得到的最小生成树。
3.简单证明prim算法
反证法:假设prim生成的不是最小生成树
1).设prim生成的树为G0
2).假设存在Gmin使得cost(Gmin)<cost(G0) 则在Gmin中存在<u,v>不属于G0
3).将<u,v>加入G0中可得一个环,且<u,v>不是该环的最长边(这是因为<u,v>∈Gmin)
4).这与prim每次生成最短边矛盾
5).故假设不成立,命题得证.
4.代码
5.时间复杂度
这里记顶点数v,边数e
邻接矩阵:O(v2) 邻接表:O(elog2v)
Kruskal算法
1.概览
Kruskal算法是一种用来寻找最小生成树的算法,由Joseph Kruskal在1956年发表。用来解决同样问题的还有Prim算法和Boruvka算法等。三种算法都是贪婪算法的应用。和Boruvka算法不同的地方是,Kruskal算法在图中存在相同权值的边时也有效。
2.算法简单描述
1).记Graph中有v个顶点,e个边
2).新建图Graphnew,Graphnew中拥有原图中相同的e个顶点,但没有边
3).将原图Graph中所有e个边按权值从小到大排序
4).循环:从权值最小的边开始遍历每条边 直至图Graph中所有的节点都在同一个连通分量中
if 这条边连接的两个节点于图Graphnew中不在同一个连通分量中
添加这条边到图Graphnew中
3.简单证明Kruskal算法
对图的顶点数n做归纳,证明Kruskal算法对任意n阶图适用。
归纳基础:
n=1,显然能够找到最小生成树。
归纳过程:
假设Kruskal算法对n≤k阶图适用,那么,在k+1阶图G中,我们把最短边的两个端点a和b做一个合并操作,即把u与v合为一个点v',把原来接在u和v的边都接到v'上去,这样就能够得到一个k阶图G'(u,v的合并是k+1少一条边),G'最小生成树T'可以用Kruskal算法得到。
我们证明T'+{<u,v>}是G的最小生成树。
用反证法,如果T'+{<u,v>}不是最小生成树,最小生成树是T,即W(T)<W(T'+{<u,v>})。显然T应该包含<u,v>,否则,可以用<u,v>加入到T中,形成一个环,删除环上原有的任意一条边,形成一棵更小权值的生成树。而T-{<u,v>},是G'的生成树。所以W(T-{<u,v>})<=W(T'),也就是W(T)<=W(T')+W(<u,v>)=W(T'+{<u,v>}),产生了矛盾。于是假设不成立,T'+{<u,v>}是G的最小生成树,Kruskal算法对k+1阶图也适用。
由数学归纳法,Kruskal算法得证。
4.代码算法实现
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<algorithm>
using namespace std;
const int maxn=405;
int n,tot;
bool G[maxn][maxn];
int w[maxn],fa[maxn];
struct node{
int u,v,w;
}e[maxn*maxn];
bool cmp(node a,node b){return a.w<b.w;}
int find(int x){//并查集
if(fa[x]==x)return x;
else return fa[x]=find(fa[x]);
}
int main(){
int i,j,p,ans=0;
scanf("%d",&n);
for(i=1;i<=n;i++){
scanf("%d",&w[i]);
fa[i]=i;
}
for(i=1;i<=n;i++){
for(j=1;j<=n;j++){
scanf("%d",&p);
if(i!=j&&p!=0){
e[++tot].u=i;e[tot].v=j;e[tot].w=p;//用边集数组存边
}
}
}
sort(e+1,e+1+tot,cmp);
for(i=1;i<=tot;i++){
int x=find(e[i].u);
int y=find(e[i].v);
if(x!=y){//当两个点不属于一个并查集,就加入边集中
fa[x]=y;
ans+=e[i].w;
}
}
printf("%d",ans);
return 0;
}
次小生成树:
我一般写次小生成树是用两种方法:
1.用Prim先求最小生成树,在求最小生成树时记录每两点之间的最大值(顶点数,边数较小)
求次小生成树的过程就是每次给原先的最小生成树加边,因为最小生成树已经是联通图,所以每次加边必然会形成环,只要删去环上的最大边,就会得到新的生成树,每次统计,更新答案即可
代码:
#include<iostream>
#include<string>
#include<cstdio>
#include<map>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int inf=1000000000;
const int maxn=1005;
int m[maxn][maxn],dist[maxn];
int path[maxn][maxn],pre[maxn];
bool used[maxn][maxn],vis[maxn];
int N,M;
int Prim(){
int sum=0,i,j;
memset(vis,0,sizeof(vis));
memset(used,0,sizeof(used));
memset(path,0,sizeof(path));
vis[1]=1;
for(i=1;i<=N;i++){
dist[i]=m[1][i];
pre[i]=1;
}
for(i=1;i<N;i++)
{
int u=-1;
for(j=1; j<=N;j++)
{
if(!vis[j])
{
if(u==-1||dist[j]<dist[u])
u=j;
}
}
used[u][pre[u]]=used[pre[u]][u]=true;
sum+=m[pre[u]][u];
vis[u]=1;
for(int j=1; j<=N;j++)
{
if(vis[j]&&j!=u)
{
path[u][j]=path[j][u]=max(path[j][pre[u]],dist[u]);
}
if(!vis[j])
{
if(dist[j]>m[u][j])
{
dist[j]=m[u][j];
pre[j]=u;
}
}
}
}
return sum;
}
int main()
{
int t,i,j;
scanf("%d",&t);
while(t--){
scanf("%d%d",&N,&M);
for(i=0;i<=N;i++)
for(j=i+1;j<=N; j++)
m[i][j]=m[j][i]=inf;
int u,v,w;
for(i=1;i<=M;i++)
{
scanf("%d%d%d",&u,&v,&w);
m[u][v]=m[v][u]=w;
}
int Mst=Prim();
int res=inf;
for(i=1; i<=N;i++)
{
for(j=1; j<=N;j++)
if(i!=j)
{
if(!used[i][j])
res=min(res,Mst+m[i][j]-path[i][j]);
}
}
cout<<res;
}
return 0;
}
2.先做一次最小生成树,但是在给原先最小生成树加边之前用倍增lca预处理出每两点之间的最大值和次大值(适用于顶点数,边数较大的情况)
代码:
#include<iostream>
#include<string>
#include<cstdio>
#include<map>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<bits/stdc++.h>
#include<ctime>
using namespace std;
const int inf=1000000000;
const int maxn=100005;
const int maxm=300005;
int N,M,tot=0,ans=inf;
long long mst;
int head[maxn],fa[maxn],dep[maxn];
bool vis[maxm];
int anc[maxn][20],max1[maxn][20],max2[maxn][20];
struct node{int to,w,next;}e[maxm*2];
struct edge{int u,v,w;}q[maxm*2];
inline int read(){
int x=0;char ch=getchar();
while(ch<'0'||ch>'9'){ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x;
}
inline void addedge(int u,int v,int w){
e[++tot].to=v;e[tot].w=w;e[tot].next=head[u];head[u]=tot;
e[++tot].to=u;e[tot].w=w;e[tot].next=head[v];head[v]=tot;
}
inline int find(int x){return fa[x]==x?x:fa[x]=find(fa[x]);}
bool cmp(edge a,edge b){return a.w<b.w;}
inline int Min(int x,int y){if(x<y)return x;else return y;}
inline int Max(int x,int y){if(x>y)return x;else return y;}
inline void dfs(int x,int father){
register int i;
for(i=1;i<=17&&(1<<i)<=dep[x];i++){
anc[x][i] = anc[anc[x][i-1]][i-1];
int Next = anc[x][i-1];
max1[x][i] = Max(max1[x][i-1],max1[Next][i-1]);
if(max1[x][i-1] != max1[x][i]) max2[x][i] = Max(max2[x][i],max1[x][i-1]);
if(max1[Next][i-1] != max1[x][i]) max2[x][i] = Max(max2[x][i],max1[Next][i-1]);
if(max2[x][i-1] != max1[x][i]) max2[x][i] = Max(max2[x][i],max2[x][i-1]);
if(max2[Next][i-1] != max1[x][i]) max2[x][i] = Max(max2[x][i],max2[Next][i-1]);
}
for(i=head[x];i;i=e[i].next){
if(e[i].to!=father){
anc[e[i].to][0]=x;
max1[e[i].to][0]=e[i].w;
dep[e[i].to]=dep[x]+1;
dfs(e[i].to,x);
}
}
}
inline int lca(int x,int y){
register int i;
if(dep[x]<dep[y])swap(x,y);
int depth=dep[x]-dep[y];
for(i=0;i<=17&&(1<<i)<=depth;i++)
if((1<<i)&depth)
x=anc[x][i];
for(i=17;i>=0;i--){
if(anc[x][i]!=anc[y][i]){
x=anc[x][i];
y=anc[y][i];
}
}
if(x==y)return x;
else return anc[x][0];
}
inline void work(int x,int y,int W)
{
register int ans1=0,ans2=0,i;
if(dep[x]<dep[y])swap(x,y);
int depth=dep[x]-dep[y];
for(i=0;i<=17&&(1<<i)<=depth;i++){
if((1<<i)&depth){
if(max1[x][i]>ans1){
ans2=ans1;
ans1=max1[x][i];
}
ans2=Max(ans2,max2[x][i]);
x=anc[x][i];
}
}
if(W==ans1)ans=Min(ans,W-ans2);
else ans=Min(ans,W-ans1);
}
int main(){
register int i,j,count=0;
N=read();M=read();
register int u,v,w;
for(i=1;i<=N;i++)fa[i]=i;
for(i=1;i<=M;i++){q[i].u=read();q[i].v=read();q[i].w=read();}
sort(q+1,q+1+M,cmp);
for(i=1;i<=M;i++){
int x=find(q[i].u);
int y=find(q[i].v);
if(x!=y){
fa[x]=y;
vis[i]=1;
addedge(q[i].u,q[i].v,q[i].w);
mst+=q[i].w;
if(++count==N-1)break;
}
}
dfs(1,0);
for(i=1;i<=M;i++){
if(!vis[i]){
int U=q[i].u,V=q[i].v,W=q[i].w;
int LCA=lca(U,V);
work(U,LCA,W);
work(V,LCA,W);
}
}
printf("%lld\n",mst+ans);
return 0;
}