题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=6686
题意:
你在一棵树上要选取两条互不相交的路径,假设两条路径的长度分别为 l 1 , l 2 l_{1},l_{2} l1,l2 ,问你有多少种不同的 p a i r pair pair 对数 < l 1 , l 2 > <l_{1},l_{2}> <l1,l2> 。
做法:
可以想象的是,如果我们要在树上找一条最长的路径的话,那么这个路径一定是树的直径。那么如果我们要找两条路并且希望这两条路尽可能的长,那就有两种可能的结果,第一个是把树的直径拆开,两边的点各自向儿子(非直径)结点延伸找一条最长的边。第二个情况是一条就是我们的直径,另一条就是某个非直径的结点的长链。
分别对应下图中的左边和右边的情况。
在知道了这样的情况之后,我们就需要dp来维护我们的两个值了,一个是经过该点的最长单链的长度
d
p
[
0
]
[
i
]
dp[0][i]
dp[0][i] , 另一个是经过该点的最长链
d
p
[
1
]
[
i
]
dp[1][i]
dp[1][i] (即这个点两个儿子的最长单链+自己)。在得到这两个dp后,我们在直径上进行移动,用变量
L
[
i
]
L[i]
L[i]和
R
[
i
]
R[i]
R[i]来表示在左边有
i
i
i 个的时候,右边最多会有
L
[
i
]
L[i]
L[i] 个结点。
代码
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++)
#define per(i,a,b) for(int i=(int)a;i>=(int)b;i--)
using namespace std;
typedef long long ll;
const int maxn=110000;
const int maxm=220005;
int ans[maxn],fa[maxn],L[maxn],R[maxn];
int to[maxm],nex[maxm],cnt,head[maxn];
int dp[2][maxn],n,dis[maxn],fid,flag[maxn];
vector<int> link;
void add(int u,int v){
to[cnt]=v,nex[cnt]=head[u]; head[u]=cnt++;
}
void getp(int u,int f){
fa[u]=f;
for(int i=head[u];i!=-1;i=nex[i]){
int v=to[i];
if(v==f) continue;
dis[v]=dis[u]+1;
if(dis[v]>dis[fid]) fid=v;
getp(v,u);
}
}
void dfs(int u,int f){
//0代表经过该点的最长单链
//1代表经过该点的最长链
dp[0][u]=dp[1][u]=1;
for(int i=head[u];~i;i=nex[i]){
int v=to[i];
if(v==f||flag[v]) continue;
dfs(v,u);
dp[1][u]=max(dp[1][u],dp[0][u]+dp[0][v]);
dp[0][u]=max(dp[0][u],dp[0][v]+1);
}
}
int main(){
/* int size = 512 << 20; // 512MB
char *p = (char*)malloc(size) + size;
__asm__("movl %0, %%esp\n" :: "r"(p));*/
int t;scanf("%d",&t);
int ff=0;
while(t--){
cnt=0;
link.clear();
scanf("%d",&n);
rep(i,1,n) {
head[i]=-1,ans[i]=0;
flag[i]=0,fa[i]=0;
L[i]=R[i]=0;
}
rep(i,1,n-1) {
int x,y; scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
fid=1;//!!!!!!!!!!!!
dis[1]=0; getp(1,0);
int tmp=fid; fid=1;
dis[tmp]=0; getp(tmp,0);
for(int i=fid;i!=0;i=fa[i]) link.push_back(i),flag[i]=1;
int sz=link.size();
for(int i=0;i<sz;i++){
dfs(link[i],0);
}
for(int i=0;i<sz;i++){
int u=link[i];
L[i]=i+dp[0][u];
R[i]=sz-i-1+dp[0][u];
}
rep(i,1,sz-1) L[i]=max(L[i],L[i-1]);
per(i,sz-2,0) R[i]=max(R[i],R[i+1]);
rep(i,1,sz){
int l=L[i-1],r=R[i];
ans[l]=max(ans[l],r);
ans[r]=max(ans[r],l);
}
rep(i,1,n) {
if(!flag[i]) {
tmp=dp[1][i];
ans[sz]=max(ans[sz],tmp);
ans[tmp]=max(ans[tmp],sz);
}
}
ll fans=0;
per(i,n,1) ans[i-1]=max(ans[i-1],ans[i]);
rep(i,1,n) fans+=(ll)ans[i];
printf("%lld\n",fans);
}
return 0;
}