题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6567
题目大意:给出n-2条边形成两棵树,需要连一条边将它们变成一棵树使得任意点距离和最小,并输出这个距离和。
题目思路:
法一:由树的重心性质,重心满足所有点到该点的距离和最小,所以求出两棵树的重心,连起来跑个树上任意点距离和板子即可。
法二:树形DP。假设你不知道重心如此强大,也没有关系。直接按照它的要求,两棵树都要选择到其他所有点距离和最小的点连接起来,因为最终的答案是两棵树各自之间两两距离和+左边的树每个点到交界处的距离*右边树点的个数(每个点都要去一次右边的点,需要过去的次数就是右边树点的个数)+右边的树每个点到交界处的距离*左边树点的个数+左边点的个数*右边点的个数(连接的桥的贡献),可以发现唯一的变量就是两棵树每个点到交界处的距离,让他最小就行,所以我们需要得到每个点到其他任意点的距离和,最小的那家伙就是我们要找的。问题转换成如何找到这家伙。以左树为例(俩一样的),第一次dfs先得到一个sumson[x]和dpsum[x],此时这两个变量分别表示x的子树大小和所有孩子到x的距离和,然后就是精髓部分,就是dfs2,首先可以发现经过dfs以后dpsum[root]已经是根到所有点的距离,但是其他人都是到孩子的距离,所以就需要通过各自的爸爸更新。现在已经有了它跟它儿子的所有距离,所以它需要得到另一块的距离。首先对于根的儿子,dpsum[fa]是它爸到所有点的距离,减去dpsum[x],要注意此时两个dpsum表示的含义已经不同了,它爸已经被更新过了,表示的是它爸到所有点的距离,而它还仅仅只是它到所有孩子的距离,所以减去后得到的是它爸到所有除了它孩子以外的所有点的距离。它跟它爸之间这条边的贡献,随着主人公从它爸到它,悄然发生了变化,所以就需要进行更新,之前这条边的贡献是sumson[x],x这颗子树的大小,但是变成了它以后,就变成了它爸那块树的大小,所以就是-sumson[x]+num-sumson[x],具体可以看看代码。然后所有的dpsum[x]加起来就是每个点到其他任意点的和,/2就是任意距离和了,接着就没啥问题了。
以下是代码:
法一:
#include<bits/stdc++.h>
#include<unordered_map>
using namespace std;
#define rep(i,a,b) for(ll i=a;i<=b;i++)
#define per(i,a,b) for(ll i=a;i>=b;i--)
#define inf 0x3f3f3f3f
#define ll long long
const ll MAXN = 2e5+5;
ll n,u,v;
vector<ll>g[MAXN];
ll vis[MAXN],siz[MAXN],dp[MAXN],ans,pos,pos2,num;
queue<ll>q;
void dfs(ll x){
vis[x]=siz[x]=1;
ll max_part=0;
ll len=g[x].size();
rep(i,0,len-1){
ll y=g[x][i];
if(vis[y])continue;
dfs(y);
siz[x]+=siz[y];
max_part=max(max_part,siz[y]);
}
max_part=max(max_part,num-siz[x]);
if(max_part<ans){
ans=max_part;
pos=x;
}
}
void dfs2(ll x,ll fa){
siz[x]=1;
ll len=g[x].size();
rep(i,0,len-1){
ll y=g[x][i];
if(y==fa)continue;
dfs2(y,x);
siz[x]+=siz[y];
dp[x]+=dp[y]+siz[y]*(n-siz[y]);
}
}
int main()
{
while(~scanf("%lld",&n)){
if(n==2){
cout<<1<<endl;
continue;
}
memset(vis,0,sizeof(vis));
memset(dp,0,sizeof(dp));
while(!q.empty())q.pop();
rep(i,1,n)g[i].clear();
rep(i,1,n-2){
scanf("%lld%lld",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
q.push(1);
vis[1]=1;
num=0;
while(!q.empty()){
ll u=q.front();
q.pop();
num++;
ll len=g[u].size();
rep(i,0,len-1){
ll y=g[u][i];
if(vis[y])continue;
vis[y]=1;
q.push(y);
}
}
memset(vis,0,sizeof(vis));
ans=inf;
dfs(1);
pos2=pos;
num=n-num;
ans=inf;
rep(i,1,n){
if(!vis[i]){
dfs(i);
break;
}
}
g[pos].push_back(pos2);
g[pos2].push_back(pos);
dfs2(1,-1);
printf("%lld\n",dp[1]);
}
return 0;
}
法二:
#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
#define rep(i,a,b) for(ll i=a;i<=b;i++)
#define per(i,a,b) for(ll i=a;i>=b;i--)
#define ll long long
const ll MAXN = 1e5+5;
ll n,x,y,num1,num2,sumson[MAXN],dpsum[MAXN],pos,vis[MAXN];
vector<ll>v[MAXN];
void dfs(ll x,ll fa,ll flag){
sumson[x]=1,dpsum[x]=0,vis[x]=flag;
if(flag==1)num1++;
ll len=v[x].size();
rep(i,0,len-1){
ll y=v[x][i];
if(y==fa)continue;
dfs(y,x,flag);
sumson[x]+=sumson[y];
dpsum[x]+=dpsum[y]+sumson[y];
}
}
void dfs2(ll x,ll fa,ll num){
if(x!=1&&x!=pos){
dpsum[x]+=dpsum[fa]-dpsum[x]-sumson[x]+num-sumson[x];
}
ll len=v[x].size();
rep(i,0,len-1){
ll y=v[x][i];
if(y==fa)continue;
dfs2(y,x,num);
}
}
int main()
{
while(~scanf("%lld",&n)){
if(n==2){
cout<<1<<endl;
continue;
}
memset(vis,0,sizeof(vis));
rep(i,1,n)v[i].clear();
rep(i,1,n-2){
scanf("%lld%lld",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
pos=num1=num2=0;
dfs(1,-1,1);
dfs2(1,-1,num1);
num2=n-num1;
rep(i,1,n){
if(!vis[i]){
pos=i;
break;
}
}
dfs(pos,-1,2);
dfs2(pos,-1,num2);
ll minn1=1ll<<62,minn2=1ll<<62,ans=0;
rep(i,1,n){
ans+=dpsum[i];
if(vis[i]==1){
if(minn1>dpsum[i]){
minn1=dpsum[i];
}
}
else{
if(minn2>dpsum[i]){
minn2=dpsum[i];
}
}
}
ans/=2;
ans=ans+minn1*num2+minn2*num1+num1*num2;
printf("%lld\n",ans);
}
return 0;
}