题目
题意:
给出一棵树( n<=5000 ),n-1条边均有权值,先要拆掉一条边,然后补上一条边,使得仍然是树,求所有两点对距离和的最小值。
解:
可能这题难度并不大,不过我做这题并不大顺利
由于看到n的数量级,复杂度不能超过
O(n2)
,一开始想不到什么巧妙算法,后来想到可以枚举边,如果枚举拆掉边+补上的边,时间复杂度已经远超O(n^2),后来一度想不出来。
最后发现,拆掉一条边后,树分为A、B两棵树,A、B树内点对的距离不变,变得只会是A中某点到B中某点的距离,我们只需要想办法让他们的总和最小即可,因为边AB的权值是个定的,容易想到在A中选一个到所有点距离之和最小的点即可、B中同样。
然后就成了找树内到所有点距离和最小的点即可,这是个经典问题,先要统计每个点子树内结点个数,然后就是计算所有点到子树内所有点的距离和,之后再计算所有点到整棵树所有点距离和的值。两次dfs即可。对于找到A、B中两点后,本来想再建棵树来求答案,结果发现这样很麻烦,最后发现可以有计算公式,直接根据刚才树形dp的结果和枚举边的权值计算出答案。在这些枚举中取最小答案即可。
之后写完数组开小re2次,然后发现ac,看了下网上的代码,发现枚举边进行dfs时,对这条边上的两点,互作为父节点,进行dfs即可,这样自然分为两棵树,写法比我写的简单。之前我是枚举边,然后标记边,之后这些边不能走…
代码已经做了修改:
代码:
#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<vector>
using namespace std;
#define all(x) (x).begin(), (x).end()
#define for0(a, n) for (int (a) = 0; (a) < (n); (a)++)
#define for1(a, n) for (int (a) = 1; (a) <= (n); (a)++)
#define mes(a,x,s) memset(a,x,(s)*sizeof a[0])
#define mem(a,x) memset(a,x,sizeof a)
#define ysk(x) (1<<(x))
typedef long long ll;
typedef pair<int, int> pii;
const int INF =0x3f3f3f3f;
const int maxn= 5000 ;
int n,fir[maxn+10],tmpN;
int nedge,to[2*maxn+10],nex[2*maxn+10],dis[2*maxn+10];bool no[2*maxn+10];
ll ans,dp[maxn+10],num[maxn+10],allDis[maxn+10];
int p[2];
inline void add_edge(int x,int y,int w)
{
to[nedge]=y;dis[nedge]=w;nex[nedge]=fir[x];
fir[x]=nedge++;
}
void dfs(int x,int fa)
{
num[x]=0;dp[x]=0;
for(int i=fir[x];~i;i=nex[i])
{
int y=to[i];if(y==fa) continue;
int w=dis[i];
dfs(y,x);
num[x]+=num[y];
dp[x]+=dp[y]+(ll)num[y]*w;
}
num[x]++;
}
void update(int x,int k)
{
if(p[k]<0|| allDis[x]<allDis[p[k]] ) p[k]=x;
}
void dfs2(int x,int fa,int k)
{
for(int i=fir[x];~i;i=nex[i])
{
int y=to[i];if(y==fa) continue;
int w=dis[i];
allDis[y]=allDis[x]-(ll)w*num[y]+ (ll)(tmpN-num[y])*w;
update(y,k);
dfs2(y,x,k);
}
}
void solve()
{
ans=-1; int N[2];
for(int i=0;i<2*(n-1);i+=2 )
{
mem(p,-1);
int x=to[i],y=to[i^1],w=dis[i];
dfs(x,y);tmpN=N[0]=num[x];allDis[x]=dp[x];update(x,0);dfs2(x,y,0);
dfs(y,x);tmpN=N[1]=num[y];allDis[y]=dp[y];update(y,1);dfs2(y,x,1);
ll ret= N[1]*( allDis[p[0] ]+(ll)N[0]*w )+N[0]*allDis[p[1] ];
ll ret2=0;
for1(j,n) ret2+=allDis[j];
ret2/=2;
ret+=ret2;
if(ans<0||ret<ans) ans=ret;
}
cout<<ans<<endl;
}
int main()
{
std::ios::sync_with_stdio(false);
int x,y,w;
while(cin>>n)
{
mes(fir,-1,n+1);nedge=0;
for0(i,n-1)
{
cin>>x>>y>>w;
add_edge(x,y,w);
add_edge(y,x,w);
}
solve();
}
return 0;
}