题意:
给出一个n个节点的图,以及任两个节点间的距离,求最小生成树中,必要边的数目。(必要边:表示这条边一定存在于树中)
范围:
0<n<=3000
分析:
给出的是一个稠密图,且用邻接矩阵给出,所以用Prim算法求出最小生成树,并且在求最小生成树的过程中,将这棵树用邻接表的形式储存下来。
然后题目应该是最小生成树唯一性判断的升级版,思路同最小生成树唯一性的判断。
将那些非树枝的边加上去,假设为边A,如果构成的环中有边B的权值等于A的值,那么B肯定不是必要的边。
就是根据这个思路,将多余的边(非树上的边)分别加到树上,判断即可。
这里的处理方法是用dp[u][v],表示在加上去的边中,覆盖了边u-v(覆盖的意思:构成的环中有边u-v),且权值最小的边。
可以知道如果dp[u][v]==dis[u][v],则u-v非必要边,否则是必要边。
具体代码:
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <climits>
#include <cmath>
#include <cassert>
#include <iostream>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <list>
#include <queue>
#include <stack>
#include <deque>
#include <algorithm>
using namespace std;
typedef long long ll;
const double eps=1e-8;
const int maxn=3010;
const int inf=10000000;
int dis[maxn][maxn],dp[maxn][maxn],lowc[maxn];
int head[maxn],tol,pre[maxn];
bool vis[maxn],mp[maxn][maxn];
struct node
{
int next,to;
} edge[2*maxn];
void add(int u,int v)
{
edge[tol].to=v;
edge[tol].next=head[u];
head[u]=tol++;
}
int prim(int n)
{
memset(vis,0,sizeof(vis));
memset(mp,0,sizeof(mp));
memset(pre,0,sizeof(pre));
int ans=0;
vis[1]=1;
pre[1]=-1;
lowc[1]=0;
for(int i=2; i<=n; i++)lowc[i]=dis[1][i],pre[i]=1;
for(int i=2; i<=n; i++)
{
int minc=inf;
int p=-1;
for(int j=1; j<=n; j++)
if(!vis[j]&&minc>lowc[j])
{
minc=lowc[j];
p=j;
}
ans+=minc;
vis[p]=1;
mp[pre[p]][p]=mp[p][pre[p]]=1;
add(pre[p],p);
add(p,pre[p]);
for(int j=1; j<=n; j++)
if(!vis[j]&&lowc[j]>dis[p][j])
{
lowc[j]=dis[p][j];
pre[j]=p;
}
}
return ans;
}
int dfs(int cur,int u,int fa)
{
int res=inf;
for(int i=head[u]; i!=-1; i=edge[i].next)
{
int v=edge[i].to;
if(v==fa)continue;
int tmp=dfs(cur,v,u);
//dp[u][v]:u->v表示树上的边,dp[u][v]表示被非树上的边覆盖的最小的边权值是多少。
dp[u][v]=dp[v][u]=min(dp[u][v],tmp);
res=min(res,tmp);
}
if(fa!=cur)res=min(res,dis[cur][u]);
return res;
}
int main()
{
int T,i,j,k,m,n;
scanf("%d",&T);
while(T--)
{
cin>>n;
memset(head,-1,sizeof(head));
tol=0;
for(int i=1; i<=n; i++)for(int j=1; j<=n; j++)dis[i][j]=inf;
for(int i=1; i<n; i++)
{
for(int j=i+1; j<=n; j++)
{
int x;
scanf("%d",&x);
dis[i][j]=dis[j][i]=x;
}
}
int ANS=prim(n);
if(ANS>inf)
{
puts("0");
continue;
}
for(i=1; i<=n; i++)
for(j=1; j<=n; j++)
dp[i][j]=inf;
for(i=1; i<=n; i++)dfs(i,i,-1);
int ans=0;
for(i=1; i<=n; i++)
for(j=i+1; j<=n; j++)
if(mp[i][j])
{
int tt=ANS-dis[i][j]+dp[i][j];
if(tt>ANS)ans++;
}
cout<<ans<<endl;
}
return 0;
}